"""Hyperbolic layers."""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.modules.module import Module
from torch_geometric.utils import to_dense_adj
# from lorentz import Lorentz /
from geoopt import Lorentz
# from geoopt.manifolds.stereographic.math import *
from misc.utils_math import artanh, tanh


def get_dim_act_curv(args):
    """
    Build the per-layer dimensions, activations and curvatures from ``args``.

    Reads ``args.act``, ``args.num_layers``, ``args.feat_dim``, ``args.dim``,
    ``args.task``, ``args.c`` and, when the curvature is fixed, ``args.cuda``
    and ``args.device``.

    :param args: namespace-like object carrying the attributes listed above.
    :return: tuple ``(dims, acts, curvatures)``; for the ``'lp'`` and ``'rec'``
        tasks one extra dimension, activation and curvature is appended.
    """
    # Falsy ``args.act`` means "no activation": use the identity.
    act = getattr(F, args.act) if args.act else (lambda x: x)

    hidden_layers = args.num_layers - 1
    acts = [act] * hidden_layers
    dims = [args.feat_dim] + [args.dim] * hidden_layers
    n_curvatures = hidden_layers
    if args.task in ('lp', 'rec'):
        dims.append(args.dim)
        acts.append(act)
        n_curvatures = args.num_layers

    if args.c is None:
        # No curvature given: one trainable parameter per layer, initialised to 1.
        curvatures = [nn.Parameter(torch.Tensor([1.])) for _ in range(n_curvatures)]
    else:
        # Fixed curvature shared by value across layers.
        curvatures = [torch.tensor([args.c]) for _ in range(n_curvatures)]
        if args.cuda != -1:
            curvatures = [curvature.to(args.device) for curvature in curvatures]
    return dims, acts, curvatures


class HNNLayer(nn.Module):
    """
    Hyperbolic neural networks layer.

    Chains a ``HypLinear`` transform with a ``HypAct`` activation; the same
    curvature ``c`` is used as both the input and output curvature of the
    activation.
    """

    def __init__(self, manifold, in_features, out_features, c, dropout, act, use_bias):
        super(HNNLayer, self).__init__()
        # HypLinear accepts no ``scale`` argument; passing ``scale=10`` here
        # made every construction of this layer fail with a TypeError.
        self.linear = HypLinear(manifold, in_features, out_features, c, dropout, use_bias)
        self.hyp_act = HypAct(manifold, c, c, act)

    def forward(self, x):
        """Apply the linear transform, then the activation, and return the result."""
        # Call the sub-modules rather than ``.forward`` so registered hooks run.
        h = self.linear(x)
        return self.hyp_act(h)


class LGCNConv(nn.Module):
    """
    Hyperbolic graph convolution layer.

    Applies a ``LorentzLinear`` transform to the node features and then a
    ``LorentzAgg`` aggregation over the adjacency ``adj``.
    """

    def __init__(self, in_features, out_features, dropout=0.0, use_att=0, local_agg=0, use_bias=1, nonlin=None, rescale=False):
        super(LGCNConv, self).__init__()

        # Keyword arguments keep the mapping onto the sub-layer options explicit.
        self.linear = LorentzLinear(
            in_features,
            out_features,
            rescale=rescale,
            bias=use_bias,
            dropout=dropout,
            nonlin=nonlin,
        )
        self.agg = LorentzAgg(out_features, dropout, use_att, local_agg, rescale=rescale)

    def forward(self, x, adj, k):
        """
        :param x: node features fed to the linear layer.
        :param adj: adjacency passed through to the aggregation layer.
        :param k: curvature value forwarded to both sub-layers.
        :return: the aggregated node features (the adjacency is not returned).
        """
        h = self.linear(x, k)
        return self.agg(h, adj, k)


class LorentzLinear(nn.Module):
    """
    Linear layer whose output is re-assembled as ``[time, spatial]`` coordinates.

    The first output channel of the wrapped ``nn.Linear`` drives the time
    coordinate, ``sigmoid(.) * exp(scale) + 1.1``; the remaining channels are
    rescaled so that, barring the ``1e-8`` clamp, their squared norm equals
    ``time ** 2 - 1``.

    With ``rescale`` set, the whole vector is then multiplied by ``sqrt(k)``.
    Otherwise the time coordinate is recomputed as
    ``sqrt(||spatial||^2 + k)``.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 rescale=True,
                 bias=True,
                 dropout=0.1,
                 scale=10,
                 fixscale=False,
                 nonlin=None):
        super().__init__()
        self.nonlin = nonlin
        self.in_features = in_features
        self.out_features = out_features
        self.rescale = rescale
        self.bias = bias
        self.weight = nn.Linear(self.in_features, self.out_features, bias=bias)
        self.reset_parameters()
        self.dropout = nn.Dropout(dropout)
        # Stored in log-space; frozen when ``fixscale`` is set.
        self.scale = nn.Parameter(torch.ones(()) * math.log(scale), requires_grad=not fixscale)

    def forward(self, x, k):
        """
        :param x: input features; the last dimension is fed to the linear map.
        :param k: curvature value (number or tensor) used for the final scaling.
        :return: tensor whose last dimension is ``[time, spatial...]``.
        """
        if self.nonlin is not None:
            x = self.nonlin(x)

        projected = self.weight(self.dropout(x))
        head, tail = projected.split([1, projected.shape[-1] - 1], dim=-1)
        time = head.sigmoid() * self.scale.exp() + 1.1

        # Both modes share the same spatial rescaling.
        tail_sq_norm = (tail * tail).sum(dim=-1, keepdim=True).clamp_min(1e-8)
        spatial = tail * ((time * time - 1) / tail_sq_norm).sqrt()

        if self.rescale:
            # Kept as an attribute, as before, so it can be inspected afterwards.
            self.rescale_factor = (k / 1.0) ** 0.5
            return torch.cat([time, spatial], dim=-1) * self.rescale_factor

        time = ((spatial * spatial).sum(dim=-1, keepdim=True).clamp_min(1e-8) + k) ** 0.5
        return torch.cat([time, spatial], dim=-1)

    def reset_parameters(self):
        """Re-initialise the wrapped linear layer's weight and (if present) bias."""
        bound = 1. / math.sqrt(self.out_features)
        nn.init.uniform_(self.weight.weight, -bound, bound)
        with torch.no_grad():
            # The stride equals in_features, so the only column visited is 0:
            # the first input channel starts with zero weight.
            for column in range(0, self.in_features, self.in_features):
                self.weight.weight[:, column] = 0
        if self.bias:
            nn.init.constant_(self.weight.bias, 0)


class LorentzAgg(Module):
    """
    Lorentz aggregation layer.

    Aggregates node features with the adjacency (optionally re-weighted by a
    learned local attention), normalises the result by the square root of the
    absolute Lorentzian self inner product, and finally either multiplies by
    ``sqrt(k)`` (``rescale``) or recomputes the time coordinate (``project``).
    """

    def __init__(self, in_features, dropout, use_att, local_agg, rescale=True):
        super(LorentzAgg, self).__init__()

        self.in_features = in_features
        self.rescale = rescale
        self.dropout = dropout
        self.local_agg = local_agg
        self.use_att = use_att
        if self.use_att:
            self.key_linear = LorentzLinear(in_features, in_features)
            self.query_linear = LorentzLinear(in_features, in_features)
            self.bias = nn.Parameter(torch.zeros(()) + 20)
            self.scale = nn.Parameter(torch.zeros(()) + math.sqrt(in_features))

    def cinner(self, x: torch.Tensor, y: torch.Tensor):
        """Pairwise Lorentzian inner products: negate x's first coordinate, then ``x @ y^T``."""
        x = x.clone()  # do not mutate the caller's tensor
        x.narrow(-1, 0, 1).mul_(-1)
        return x @ y.transpose(-1, -2)

    def inner(self, u, v, keepdim: bool = True, dim: int = -1):
        """Lorentzian inner product along ``dim``: ``-u_0 * v_0 + sum_{i>0} u_i * v_i``."""
        d = u.size(dim) - 1
        uv = u * v
        if keepdim is False:
            return -uv.narrow(dim, 0, 1).sum(dim=dim, keepdim=False) + uv.narrow(
                dim, 1, d
            ).sum(dim=dim, keepdim=False)
        return torch.cat((-uv.narrow(dim, 0, 1), uv.narrow(dim, 1, d)), dim=dim).sum(
            dim=dim, keepdim=True
        )

    def project(self, x, k: torch.Tensor, dim: int = -1):
        """Replace the first coordinate along ``dim`` by ``sqrt(k + ||rest||^2)``."""
        dn = x.size(dim) - 1
        spatial = x.narrow(dim, 1, dn)
        time = torch.sqrt(k + (spatial * spatial).sum(dim=dim, keepdim=True))
        return torch.cat((time, spatial), dim=dim)

    def forward(self, x, adj, k):
        """
        :param x: node features, one row per node.
        :param adj: adjacency; used with ``torch.spmm`` when attention is off
            and densified with ``to_dense()`` for local attention.
        :param k: curvature value used for the final rescale / projection.
        :return: aggregated node features.
        :raises NotImplementedError: if ``use_att`` is set without
            ``local_agg`` and no ``att`` callable has been attached.
        """
        if self.use_att:
            if self.local_agg:
                query = self.query_linear(x, k)
                key = self.key_linear(x, k)
                att_adj = 2 + 2 * self.cinner(query, key)
                att_adj = att_adj / self.scale + self.bias
                att_adj = torch.sigmoid(att_adj)
                # Keep attention only where the graph has an edge.
                att_adj = torch.mul(adj.to_dense(), att_adj)
                support_t = torch.matmul(att_adj, x)
            else:
                # This class never defines ``att``; previously this branch
                # died with an opaque AttributeError. Fail with a clear
                # message, but still honour an externally attached ``att``.
                att = getattr(self, 'att', None)
                if att is None:
                    raise NotImplementedError(
                        'LorentzAgg: use_att without local_agg requires an `att` '
                        'attention callable, which this layer does not define.'
                    )
                support_t = torch.matmul(att(x, adj), x)
        else:
            support_t = torch.spmm(adj, x)

        denom = (-self.inner(support_t, support_t))
        denom = denom.abs().clamp_min(1e-8).sqrt()
        output = support_t / denom

        if self.rescale:
            # Kept as an attribute, as before, so it can be inspected afterwards.
            self.rescale_factor = (k / 1.0) ** 0.5
            return output * self.rescale_factor
        return self.project(output, k)

    def attention(self, x, adj):
        """Placeholder; not implemented and not called by ``forward``. Returns ``None``."""
        pass


class HyperbolicGraphConvolution(nn.Module):
    """
    Hyperbolic graph convolution layer.

    After the linear transform, nodes are scored in the tangent space at the
    origin, the top ``selection_ratio`` fraction is kept, and aggregation runs
    over the adjacency restricted to the kept nodes (plus self-loops).
    """

    def __init__(self, manifold, in_features, out_features, dropout, use_bias=False, use_att=False, local_agg=False, c_in=1.0, c_out=1.0, selection_ratio=0.7):
        super(HyperbolicGraphConvolution, self).__init__()
        self.manifold = manifold
        # Fraction of nodes kept for aggregation (previously hard-coded to 0.7).
        self.selection_ratio = selection_ratio
        self.linear = HypLinear(manifold, in_features, out_features, c_in, dropout, use_bias)
        self.agg = HypAgg(manifold, c_in, out_features, dropout, use_att, local_agg)

    def node_information_score(self, adj, h):
        """
        Score each node by the L1 norm of ``(I - D^-1 A) h``.

        ``D`` is the diagonal matrix of row sums of ``adj``. The previous code
        built the inverse degree matrix and then *solved* against it, which
        undoes the inversion (yielding ``I - D A``) and breaks on isolated
        nodes. NOTE(review): ``I - D^-1 A`` is presumed to be the intended
        operator; confirm against the reference method.

        :param adj: dense square adjacency matrix.
        :param h: node features, one row per node.
        :return: 1-D tensor with one score per node.
        """
        degree = adj.sum(dim=1)
        isolated = degree == 0
        # Isolated nodes get a zero neighbour term instead of a division by zero.
        inv_degree = torch.where(
            isolated,
            torch.zeros_like(degree),
            1.0 / degree.masked_fill(isolated, 1.0),
        )
        h_sigma = h - inv_degree.unsqueeze(-1) * (adj @ h)
        return torch.norm(h_sigma, p=1, dim=1)

    def select_nodes(self, node_scores, selection_ratio):
        """
        Return a boolean mask selecting the highest-scoring nodes.

        About ``selection_ratio * len(node_scores)`` nodes are kept (at least
        one, at most all); ties with the threshold score are kept too. The
        previous index arithmetic combined a descending sort with an
        ascending-style index, so a ratio of 0.7 kept roughly 30% of the
        nodes, 1.0 kept only the maximum and 0.0 raised ``IndexError``.
        """
        n = node_scores.shape[0]
        if n == 0:
            return torch.zeros_like(node_scores, dtype=torch.bool)
        keep = min(n, max(1, int(round(n * selection_ratio))))
        threshold = node_scores.sort(descending=True)[0][keep - 1]
        return node_scores >= threshold

    def get_submatrix_from_selected_nodes(self, adj_dense, selected_nodes):
        """
        Restrict ``adj_dense`` to edges between selected nodes and add self-loops.

        Entry ``(i, j)`` is ``adj[i, j]`` when both ``i`` and ``j`` are
        selected and zero otherwise; selected nodes additionally get ``+1`` on
        the diagonal. Rows of unselected nodes are therefore all zero. The
        previous code computed ``adj[i, j] * adj[j, i] * selected[j]``, which
        multiplied edge weights together and never masked rows.
        """
        keep = selected_nodes.to(adj_dense.dtype)
        submatrix = adj_dense * keep.unsqueeze(-1) * keep.unsqueeze(0)
        return submatrix + torch.diag_embed(keep)

    def convert_to_sparse_coo_tensor(self, submatrix):
        """Convert a dense 2-D matrix to a sparse COO tensor holding its non-zero entries."""
        idx = torch.nonzero(submatrix).T
        data = submatrix[idx[0], idx[1]]
        return torch.sparse_coo_tensor(idx, data, submatrix.shape)

    def forward(self, x, adj):
        """
        :param x: input node features passed to the linear layer.
        :param adj: adjacency matrix; densified with ``to_dense()``.
        :return: aggregated node features.
        """
        h = self.linear(x)

        adj_dense = adj.to_dense()
        tangent = self.manifold.logmap0(h)
        node_scores = self.node_information_score(adj_dense, tangent)

        selected_nodes = self.select_nodes(node_scores, self.selection_ratio)
        adj_selected = self.get_submatrix_from_selected_nodes(adj_dense, selected_nodes)
        adj = self.convert_to_sparse_coo_tensor(adj_selected)

        return self.agg(h, adj)


class HypLinear(nn.Module):
    """
    Hyperbolic linear layer.

    Applies a Mobius matrix-vector product with a (dropout-regularised) weight
    and, when ``use_bias`` is set, a Mobius addition of the bias mapped through
    ``manifold.expmap0``. Results are clipped by ``proj`` to norm at most
    ``(1 - eps) / sqrt(c)``.
    """

    def __init__(self, manifold, in_features, out_features, c, dropout, use_bias):
        super(HypLinear, self).__init__()
        self.manifold = manifold
        self.in_features = in_features
        self.out_features = out_features
        self.c = c
        self.dropout = dropout
        self.use_bias = use_bias
        # The bias parameter exists even when ``use_bias`` is false; it is
        # simply not used in ``forward`` then.
        self.bias = nn.Parameter(torch.empty(out_features))
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.reset_parameters()
        self.min_norm = 1e-15
        # Boundary margin used by ``proj``, keyed by dtype.
        self.eps = {torch.float32: 4e-3, torch.float64: 1e-5}

    def reset_parameters(self):
        """Xavier-uniform weight (gain ``sqrt(2)``) and zero bias."""
        init.xavier_uniform_(self.weight, gain=math.sqrt(2))
        init.constant_(self.bias, 0)

    def mobius_matvec(self, m, x, c):
        """
        Mobius matrix-vector product of ``m`` with the rows of ``x``.

        Rows whose Euclidean product ``x @ m^T`` is exactly zero map to zero.
        """
        sqrt_c = c ** 0.5
        x_norm = x.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm)
        mx = x @ m.transpose(-1, -2)
        mx_norm = mx.norm(dim=-1, keepdim=True, p=2).clamp_min(self.min_norm)
        res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c)
        # ``torch.where`` expects a boolean condition; the former uint8 mask
        # (via ``prod(dtype=torch.uint8)``) is deprecated.
        all_zero = (mx == 0).all(dim=-1, keepdim=True)
        res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device)
        return torch.where(all_zero, res_0, res_c)

    def mobius_add(self, x, y, c, dim=-1):
        """Mobius addition of ``x`` and ``y`` along ``dim`` for curvature ``c``."""
        x2 = x.pow(2).sum(dim=dim, keepdim=True)
        y2 = y.pow(2).sum(dim=dim, keepdim=True)
        xy = (x * y).sum(dim=dim, keepdim=True)
        num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
        denom = 1 + 2 * c * xy + c ** 2 * x2 * y2
        return num / denom.clamp_min(self.min_norm)

    def proj(self, x, c):
        """Rescale rows whose norm exceeds ``(1 - eps) / sqrt(c)`` back to that norm."""
        norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), self.min_norm)
        maxnorm = (1 - self.eps[x.dtype]) / (c ** 0.5)
        cond = norm > maxnorm
        projected = x / norm * maxnorm
        return torch.where(cond, projected, x)

    def proj_tan0(self, u, c):
        """Identity: tangent vectors at the origin are returned unchanged."""
        return u

    def forward(self, x):
        """
        :param x: input points, one per row.
        :return: transformed points after projection (and bias, if enabled).
        """
        drop_weight = F.dropout(self.weight, self.dropout, training=self.training)
        mv = self.mobius_matvec(drop_weight, x, self.c)
        res = self.proj(mv, self.c)
        if self.use_bias:
            bias = self.proj_tan0(self.bias.view(1, -1), self.c)
            hyp_bias = self.manifold.expmap0(bias)
            hyp_bias = self.proj(hyp_bias, self.c)
            res = self.mobius_add(res, hyp_bias, c=self.c)
            res = self.proj(res, self.c)
        return res

    def extra_repr(self):
        return 'in_features={}, out_features={}, c={}'.format(
            self.in_features, self.out_features, self.c
        )


class HypAgg(Module):
    """
    Hyperbolic aggregation layer.

    Maps points to the tangent space at the origin with ``manifold.logmap0``,
    aggregates them with the adjacency, and maps the result back with
    ``manifold.expmap0`` followed by a norm clip.
    """

    def __init__(self, manifold, c, in_features, dropout, use_att, local_agg):
        super(HypAgg, self).__init__()
        self.manifold = manifold
        self.c = c

        self.in_features = in_features
        self.dropout = dropout
        self.local_agg = local_agg
        self.use_att = use_att

        self.min_norm = 1e-15
        # Boundary margin used by ``proj``, keyed by dtype.
        self.eps = {torch.float32: 4e-3, torch.float64: 1e-5}

    def proj(self, x, c):
        """Rescale rows whose norm exceeds ``(1 - eps) / sqrt(c)`` back to that norm."""
        norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), self.min_norm)
        maxnorm = (1 - self.eps[x.dtype]) / (c ** 0.5)
        cond = norm > maxnorm
        projected = x / norm * maxnorm
        return torch.where(cond, projected, x)

    def forward(self, x, adj):
        """
        :param x: node embeddings, one row per node.
        :param adj: adjacency used with ``torch.spmm`` when attention is off.
        :return: aggregated node embeddings.
        :raises NotImplementedError: if ``use_att`` is set but no ``att``
            callable has been attached to this layer.
        """
        att = None
        if self.use_att:
            # This class never defines ``att``; both attention branches used to
            # fail with an opaque AttributeError. Check up front, before any
            # manifold work, and still honour an externally attached ``att``.
            att = getattr(self, 'att', None)
            if att is None:
                raise NotImplementedError(
                    'HypAgg: use_att requires an `att` attention callable, '
                    'which this layer does not define.'
                )

        x_tangent = self.manifold.logmap0(x)
        if self.use_att:
            if self.local_agg:
                # Tangent vectors of all nodes as seen from each node x_i.
                x_local_tangent = torch.stack(
                    [self.manifold.logmap(x_i, x) for x_i in x], dim=0
                )
                adj_att = att(x_tangent, adj)
                support_t = torch.sum(adj_att.unsqueeze(-1) * x_local_tangent, dim=1)
                # NOTE(review): this branch uses ``manifold.proj`` whereas the
                # others use ``self.proj`` - confirm which one is intended.
                return self.manifold.proj(self.manifold.expmap(x, support_t), c=self.c)
            support_t = torch.matmul(att(x_tangent, adj), x_tangent)
        else:
            support_t = torch.spmm(adj, x_tangent)
        return self.proj(self.manifold.expmap0(support_t), c=self.c)

    def extra_repr(self):
        return 'c={}'.format(self.c)


class HypAct(Module):
    """
    Hyperbolic activation layer.

    Pulls the input to the tangent space at the origin (curvature ``c_in``),
    applies ``act`` there, and pushes the result back onto the manifold
    (curvature ``c_out``).
    """

    def __init__(self, manifold, c_in, c_out, act):
        super().__init__()
        self.manifold = manifold
        self.c_in = c_in
        self.c_out = c_out
        self.act = act

    def forward(self, x):
        """Return ``proj(expmap0(proj_tan0(act(logmap0(x)))))`` using the stored manifold."""
        manifold = self.manifold
        tangent = manifold.logmap0(x, c=self.c_in)
        activated = manifold.proj_tan0(self.act(tangent), c=self.c_out)
        mapped = manifold.expmap0(activated, c=self.c_out)
        return manifold.proj(mapped, c=self.c_out)

    def extra_repr(self):
        return 'c_in={}, c_out={}'.format(self.c_in, self.c_out)
