import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch_geometric.utils import add_remaining_self_loops, remove_self_loops, softmax, add_self_loops
from torch_scatter import scatter, scatter_add
from torch_geometric.nn.conv import MessagePassing, GATConv
from torch.nn.parameter import Parameter
from torch_geometric.nn.inits import glorot, zeros

from script.manifolds.pseudohyperboloid_sr import PseudoHyperboloidSR


class PseudoGraphConvolution(nn.Module):
    """
    Pseudo hyperboloid graph convolution layer, based on hgcn.

    The layer chains three sub-modules that share a single
    ``PseudoHyperboloidSR`` manifold built from ``(time_dim, space_dim, beta)``:

    1. ``PseudoLinear`` - feature transform,
    2. ``PseudoAgg``    - neighborhood aggregation over ``edge_index``,
    3. ``PseudoAct``    - activation.

    Args:
        time_dim: forwarded to ``PseudoHyperboloidSR``.
        space_dim: forwarded to ``PseudoHyperboloidSR``.
        beta: forwarded to the manifold and to every sub-module.
        in_features: input feature size of the linear sub-module.
        out_features: output feature size of the linear sub-module.
        dropout: dropout probability forwarded to ``PseudoLinear``.
        act: callable forwarded to ``PseudoAct``.
        use_bias: forwarded to ``PseudoLinear``.
    """

    def __init__(self, time_dim, space_dim, beta, in_features, out_features, dropout=0.6, act=F.leaky_relu,
                 use_bias=True):
        super(PseudoGraphConvolution, self).__init__()
        self.manifold = PseudoHyperboloidSR(time_dim, space_dim, beta)
        self.beta = beta
        self.linear = PseudoLinear(self.manifold, in_features, out_features, self.beta, dropout=dropout, use_bias=use_bias)
        self.agg = PseudoAgg(self.manifold, self.beta)
        self.hyp_act = PseudoAct(self.manifold, beta, act)

    def forward(self, x, edge_index):
        """
        Run linear transform, aggregation and activation in sequence.

        Args:
            x: node features handed to the linear sub-module.
            edge_index: graph connectivity handed to the aggregation sub-module.

        Returns:
            The output of the activation sub-module.
        """
        # Invoke the sub-modules through ``__call__`` rather than ``.forward``:
        # calling ``.forward`` directly silently skips any forward/pre-forward
        # hooks registered on them.
        h = self.linear(x)
        h = self.agg(h, edge_index)
        h = self.hyp_act(h)
        return h

class PseudoLinear(nn.Module):
    """
    Pseudo hyperboloid linear layer.

    Holds a ``(out_features, in_features)`` weight and an ``(out_features,)``
    bias. The forward pass applies dropout to the *weight matrix*, multiplies
    via ``manifold.mobius_matvec`` and, when ``use_bias`` is set, adds the bias
    through ``manifold.mobius_add``.
    """

    def __init__(self, manifold, in_features, out_features, beta, dropout=0.6, use_bias=True):
        super().__init__()
        self.manifold = manifold
        self.beta = beta
        self.in_features, self.out_features = in_features, out_features
        self.dropout = dropout
        self.use_bias = use_bias
        # Bias is registered before weight (and regardless of ``use_bias``);
        # both are allocated uninitialized and filled by reset_parameters().
        self.bias = nn.Parameter(torch.Tensor(out_features))
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Glorot-initialize the weight and zero the bias."""
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x):
        """
        Apply the manifold linear map to ``x``.

        Args:
            x: input handed to ``manifold.mobius_matvec`` together with the
               (dropped-out) weight.

        Returns:
            The matvec result, with the bias added when ``use_bias`` is true.
        """
        # Dropout acts on the weights, and only while the module is training.
        weight = F.dropout(self.weight, p=self.dropout, training=self.training)
        out = self.manifold.mobius_matvec(weight, x, self.beta)
        if not self.use_bias:
            return out

        # Bias as a single row -> tangent projection -> exponential map -> add.
        # NOTE(review): this calls `proj_tan_0`, whereas PseudoAct in this file
        # calls `proj_tan0` - confirm both exist on the manifold class.
        bias_row = self.bias.view(1, -1)
        tangent_bias = self.manifold.proj_tan_0(bias_row, self.beta)
        manifold_bias = self.manifold.expmap_0(tangent_bias, self.beta)
        return self.manifold.mobius_add(out, manifold_bias, beta=self.beta)

    def extra_repr(self):
        """Return the feature sizes and beta for the module's repr."""
        template = 'in_features={}, out_features={}, beta={}'
        return template.format(self.in_features, self.out_features, self.beta)


class PseudoAct(nn.Module):
    """
    Pseudo hyperboloid activation layer.

    Applies ``act`` between ``manifold.logmap_sr_0`` and
    ``manifold.expmap_sr_0``, projecting the activated value with
    ``manifold.proj_tan0`` before mapping back.
    """

    def __init__(self, manifold, beta, act):
        super().__init__()
        self.act = act
        self.beta = beta
        self.manifold = manifold

    def forward(self, x):
        """
        Activate ``x``.

        Args:
            x: input handed to ``manifold.logmap_sr_0``.

        Returns:
            Result of ``manifold.expmap_sr_0`` on the projected activation.
        """
        manifold, beta = self.manifold, self.beta
        tangent = manifold.logmap_sr_0(x, beta=beta)
        activated = self.act(tangent)
        # NOTE(review): spelled `proj_tan0` here but `proj_tan_0` in
        # PseudoLinear - confirm both exist on the manifold class.
        projected = manifold.proj_tan0(activated, beta=beta)
        return manifold.expmap_sr_0(projected, beta=beta)

    def extra_repr(self):
        """Return beta for the module's repr."""
        return 'beta={}'.format(self.beta)


class PseudoAgg(MessagePassing):
    """
    Pseudo hyperboloid aggregation layer using degree.

    Features are mapped with ``manifold.logmap_sr_0``, combined over edges
    with degree-based coefficients, and mapped back with
    ``manifold.expmap_sr_0``.
    """

    def __init__(self, manifold, beta):
        super().__init__()
        self.beta = beta
        self.manifold = manifold

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight=None, improved=False, dtype=None):
        """
        Compute per-edge coefficients ``deg^-1/2[row] * w * deg^-1/2[col]``.

        Args:
            edge_index: index tensor unpacked into ``row, col``.
            num_nodes: number of nodes, used as the scatter output size.
            edge_weight: optional per-edge weights; all ones when omitted.
            improved: when true, self loops are filled with 2 instead of 1.
            dtype: dtype of the default all-ones edge weights.

        Returns:
            ``(edge_index, coefficients)`` where ``edge_index`` is the one
            returned by ``add_remaining_self_loops``.
        """
        if edge_weight is None:
            num_edges = edge_index.size(1)
            edge_weight = torch.ones((num_edges,), dtype=dtype, device=edge_index.device)

        loop_weight = 2 if improved else 1
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, loop_weight, num_nodes)

        row, col = edge_index
        # Degree is accumulated over the first index row.
        degree = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        inv_sqrt_degree = degree.pow(-0.5)
        # Zero-degree nodes yield inf after pow(-0.5); zero them out.
        inv_sqrt_degree[inv_sqrt_degree == float('inf')] = 0

        coefficients = inv_sqrt_degree[row] * edge_weight * inv_sqrt_degree[col]
        return edge_index, coefficients

    def forward(self, x, edge_index=None):
        """
        Aggregate neighbor features.

        Args:
            x: node features; ``x.size(0)`` is taken as the node count.
            edge_index: index tensor; row 0 selects where messages are
                accumulated, row 1 selects which node's features are sent.
                NOTE(review): the ``None`` default is not handled - ``norm``
                dereferences ``edge_index`` immediately.

        Returns:
            ``manifold.expmap_sr_0`` applied to the aggregated features.
        """
        num_nodes = x.size(0)
        tangent = self.manifold.logmap_sr_0(x, beta=self.beta)
        edge_index, coefficients = self.norm(edge_index, num_nodes, dtype=x.dtype)

        targets, sources = edge_index[0], edge_index[1]
        # Row lookup of the sending nodes' features, scaled per edge.
        messages = coefficients.view(-1, 1) * F.embedding(sources, tangent)
        # Accumulate the messages at each target node (scatter's default reduction).
        pooled = scatter(messages, targets, dim=0, dim_size=num_nodes)
        return self.manifold.expmap_sr_0(pooled, beta=self.beta)

    def extra_repr(self):
        """Return beta for the module's repr."""
        beta = self.beta
        return 'beta={}'.format(beta)
