import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import manifolds
import utils.math_utils as pmath
import torch as th
import geotorch as geoth
from utils import *
from utils import pre_utils
from utils.pre_utils import *
from manifolds import *
#from manifolds import LorentzManifold
from layers.CentroidDistance import CentroidDistance



class SRBGCN(nn.Module):

    def __init__(self, args, logger):
        """
        Build the projection layer, the per-layer orthogonal message weights and,
        when enabled, the boost parameters.

        Args:
            args: configuration object. This constructor reads `dim`, `tie_weight`,
                `feature_dim`, `proj_init`, `num_layers`, `task` and `manifold`
                from it and appends every created layer to `args.variables`.
            logger: stored on the instance and forwarded to `CentroidDistance`.
        """
        super(SRBGCN, self).__init__()
        self.debug = False
        self.args = args
        self.logger = logger
        
        # tie_list[i] is the index of the weight set used for entry i of args.dim.
        # An entry with the same size as its predecessor receives the current
        # counter value; a size change receives counter + 1 and advances the
        # counter by 2.
        # NOTE(review): advancing by 2 can leave unused indices (three distinct
        # sizes give [0, 1, 3]); the `self.tie_list.index(iii)` lookups below would
        # then raise ValueError for the missing index when tie_weight is set.
        # Confirm this is intended.
        tie_list = [0]*len(args.dim)
        counter = 0;
        for i in range(1,len(args.dim)):
            if args.dim[i] == args.dim[i-1]:
                tie_list[i] = counter
            else:
                tie_list[i] = counter + 1
                counter = counter + 2
        self.tie_list = tie_list            
        # Small constant added under the square root in retrieve_params_h.
        self.eps = 1e-14
        
        # Without weight tying every entry of args.dim gets its own index.
        if not self.args.tie_weight:
            self.tie_list = [i for i in range(len(args.dim))]
        
        # The boost parameters are created before the projection layer, so they
        # also come first in args.variables.
        self.set_up_boost_params()
        self.activation = nn.SELU()
        
        
        # Projection from the raw feature size to the first hidden size.
        self.linear = nn.Linear(
                int(args.feature_dim), int(args.dim[0]),
        )
        nn_init(self.linear, self.args.proj_init)
        self.args.variables.append(self.linear)

        self.msg_weight = []
        # Number of distinct weight sets: one per layer, or one per tied group.
        layer = self.args.num_layers if not self.args.tie_weight else self.tie_list[-1]+1

        for iii in range(layer):
            # The linear maps act on dim - 1 components; retrieve_params later
            # pads the weight with an extra leading row and column.
            # NOTE(review): for iii > 0 the first argument (in_features) uses the
            # size of the current group and the second (out_features) the size of
            # the previous one; the two only coincide when consecutive sizes are
            # equal. Confirm the ordering.
            if iii == 0:
                M =  nn.Linear(self.args.dim[self.tie_list.index(iii)]-1, self.args.dim[self.tie_list.index(iii)]-1,bias=0)
            else:
                M = nn.Linear(self.args.dim[self.tie_list.index(iii)]-1, self.args.dim[self.tie_list.index(iii-1)]-1,bias=0)
         
            # Presumably constrains M.weight to be orthogonal through geotorch's
            # Cayley parametrization -- verify against the geotorch version in use.
            geoth.orthogonal(M,"weight","cayley")

            self.args.variables.append(M)
            self.msg_weight.append(M)
        # Wrap the plain list so the layers are registered as sub-modules.
        self.msg_weight = nn.ModuleList(self.msg_weight)
        
        # The centroid-distance head is only built for the 'nc' task.
        if self.args.task == 'nc':
            self.distance = CentroidDistance(args, logger, args.manifold)

       
    '''    
    def create_params_h(self):
        """
        create the GNN params for hyperbolic rotation transformation using basic hyperbolic rotations
        """
        h_weight = []
        layer = self.args.num_layers if not self.args.tie_weight else self.tie_list[-1]+1
        for iii in range(layer):
            H = th.randn(self.args.dim[self.tie_list.index(iii)]-1, requires_grad=True)*0.01
            nn.init.uniform_(H, -0.001, 0.001)
            H = nn.Parameter(H)
            self.args.variables.append(H)
            h_weight.append(H)
        return nn.ParameterList(h_weight)
    '''    
    
    
    def create_boost_params(self):
        """
        create the GNN params for hyperbolic rotation transformation using axis and hyperbolic angle 
        """
        h_weight = []
        layer = self.args.num_layers if not self.args.tie_weight else self.tie_list[-1]+1
        for iii in range(layer):
            H = th.randn(self.args.dim[self.tie_list.index(iii)], requires_grad=True)*0.01
            H = th.ones(self.args.dim[self.tie_list.index(iii)], requires_grad=True)/np.sqrt(self.args.dim[self.tie_list.index(iii)]-1)
            H[0] = 0
            H = nn.Parameter(H)
            self.args.variables.append(H)
            h_weight.append(H)
        return nn.ParameterList(h_weight)
        

    def set_up_boost_params(self):
        """
        set up the params for all message types
        """
        self.type_of_msg = 1
        
        for i in range(0, self.type_of_msg):
            if self.args.hyp_rotation:
                setattr(self, "msg_%d_weight_h" % i, self.create_boost_params())

    def apply_activation(self, node_repr):
        """
        apply non-linearity for different manifolds
        """
        if self.args.select_manifold in {"poincare", "euclidean"}:
            return self.activation(node_repr)
        elif self.args.select_manifold == "lorentz":
            return self.args.manifold.from_poincare_to_lorentz(
                self.activation(self.args.manifold.from_lorentz_to_poincare(node_repr))
            )

    def split_graph_by_negative_edge(self, adj_mat, weight):
        """
        Split the graph according to positive and negative edges.
        """
        mask = weight > 0
        neg_mask = weight < 0

        pos_adj_mat = adj_mat * mask.long()
        neg_adj_mat = adj_mat * neg_mask.long()
        pos_weight = weight * mask.float()
        neg_weight = -weight * neg_mask.float()
        return pos_adj_mat, pos_weight, neg_adj_mat, neg_weight

    def split_graph_by_type(self, adj_mat, weight):
        """
        split the graph according to edge type for multi-relational datasets
        """
        multi_relation_adj_mat = []
        multi_relation_weight = []
        for relation in range(1, self.args.edge_type):
            mask = (weight.int() == relation)
            multi_relation_adj_mat.append(adj_mat * mask.long())
            multi_relation_weight.append(mask.float())
        return multi_relation_adj_mat, multi_relation_weight

    def split_input(self, adj_mat, weight):
        return [adj_mat], [weight]

    def p2k(self, x, c):
        denom = 1 + c * x.pow(2).sum(-1, keepdim=True)
        return 2 * x / denom


    def lorenz_factor(self, x, *, c=1.0, dim=-1, keepdim=False):
        """
            Calculate Lorenz factors
        """
        x_norm = x.pow(2).sum(dim=dim, keepdim=keepdim)
        x_norm = torch.clamp(x_norm, 0, 0.9)
        tmp = 1 / torch.sqrt(1 - c * x_norm)
        return tmp
     
    def from_lorentz_to_poincare(self, x):
        """
        Args:
            u: [batch_size, d + 1]
        """
        d = x.size(-1) - 1
        return x.narrow(-1, 1, d) / (x.narrow(-1, 0, 1) + 1)

    def h2p(self, x):
        return self.from_lorentz_to_poincare(x)

    def from_poincare_to_lorentz(self, x, eps=1e-3):
        """
        Args:
            u: [batch_size, d]
        """
        x_norm_square = x.pow(2).sum(-1, keepdim=True)
        tmp = th.cat((1 + x_norm_square, 2 * x), dim=1)
        tmp = tmp / (1 - x_norm_square)
        return  tmp

    def p2h(self, x):
        return  self.from_poincare_to_lorentz(x)

    def p2k(self, x, c=1.0):
        denom = 1 + c * x.pow(2).sum(-1, keepdim=True)
        return 2 * x / denom

    
    def test_lor(self, A):
        tmp1 = (A[:,0] * A[:,0]).view(-1)
        tmp2 = A[:,1:]
        tmp2 = th.diag(tmp2.mm(tmp2.transpose(0,1)))
        return (tmp1 - tmp2)
    
    def lorentz_mean(self, y, node_num, max_neighbor, real_node_num, weight, dim=0, c=1.0, ):
        """
        Weighted aggregation of neighbour representations, rescaled by the
        Minkowski self-product of the aggregate.

        The rows of `y` are regrouped per node, multiplied by `weight`, summed
        over the neighbour axis, and the sum is scaled by
        `sqrt(c / |<s, s>|)` where `<., .>` is the manifold's `minkowski_dot`.

        Args:
            y: [node_num * max_neighbor, dim]
            node_num: number of nodes used to regroup `y`.
            max_neighbor: neighbours per node used to regroup `y`.
            real_node_num: not used by the computation.
            weight: multiplied with the regrouped `y`
                (presumably [node_num, max_neighbor] -- TODO confirm).
            dim: not used by the computation.
            c: numerator of the rescaling coefficient.
        """
        # Regroup as [dim, node_num, max_neighbor] so `weight` applies to the
        # trailing two axes.
        per_node = y.transpose(-2, -1).view(-1, node_num, max_neighbor)
        weighted_sum = torch.mul(per_node, weight).sum(-1).transpose(-2, -1)

        self_product = self.args.manifold.minkowski_dot(weighted_sum, weighted_sum, keepdim=False)
        scale = torch.sqrt(c / torch.abs(self_product))

        # Transpose so the per-node scale lines up with the trailing axis,
        # then transpose back.
        return torch.mul(scale, weighted_sum.transpose(-2, -1)).transpose(-2, -1)

    

    def retrieve_params(self, weight, step):
        """
        Args:
            weight: a list of weights
            step: a certain layer
        """
        weight = weight[step].weight
        
        layer_weight = th.cat((th.zeros((weight.size(0), 1)).cuda().to(self.args.device), weight), dim=1)
        tmp = th.zeros((1, weight.size(1)+1)).cuda().to(self.args.device)
        tmp[0,0] = 1
        layer_weight = th.cat((tmp, layer_weight), dim=0)
        return layer_weight
    
    def retrieve_params_h(self, weight_h, step):
        """
        Build the hyperbolic rotation (boost) matrix of one layer from its
        hyperbolic angle and axis.

        `weight_h[step]` holds the angle `a` in entry 0 and the axis in the
        remaining entries. With `n` the normalised axis, the result is the
        block matrix

            [[cosh(a),      sinh(a) * n^T              ],
             [sinh(a) * n,  I - (1 - cosh(a)) * n n^T  ]]

        The identity block is created directly on `args.device`, so the
        method also works when CUDA is not available.

        Args:
            weight_h: indexable collection of 1-D parameter tensors.
            step: index of the layer to use.
        Returns:
            square matrix whose side equals the length of `weight_h[step]`.
        """
        params = weight_h[step]
        angle = params[0]
        axis = params[1:]

        # self.eps keeps the normalisation finite for an all-zero axis.
        unit_axis = axis / th.sqrt(axis.pow(2).sum(-1, keepdim=True) + self.eps)

        cosh_a = pmath.cosh(angle)
        sinh_a = pmath.sinh(angle)
        boost = sinh_a * unit_axis

        identity = th.eye(params.size(0) - 1, device=self.args.device)
        spatial = identity - (1 - cosh_a) * th.outer(unit_axis, unit_axis)

        lower_rows = th.cat((boost.reshape((-1, 1)), spatial), dim=1)
        top_row = th.cat([cosh_a.reshape(1), boost])
        return th.cat((top_row.reshape((1, -1)), lower_rows), dim=0)
    
    def aggregate_msg(self, node_repr, adj_mat, weight, layer_weight, layer_weight_h, mask, c):
        """
        Message passing for a specific message type.

        Node representations are multiplied by the transposed `layer_weight`
        and, when `args.hyp_rotation` is enabled, by `layer_weight_h`. The
        result is masked, gathered per neighbour through `adj_mat`, and
        aggregated with `lorentz_mean`.

        Args:
            node_repr: 2-D tensor of node representations.
            adj_mat: 2-D integer tensor of neighbour indices, one row per
                node (its shape gives node_num and max_neighbor).
            weight: neighbour weights forwarded to `lorentz_mean`.
            layer_weight: matrix applied as `node_repr @ layer_weight.T`.
            layer_weight_h: matrix applied as `msg @ layer_weight_h`; only
                used when `args.hyp_rotation` is enabled.
            mask: multiplied element-wise with the messages.
            c: forwarded to `lorentz_mean`.
        Returns:
            the aggregated message returned by `lorentz_mean`.
        """
        node_num, max_neighbor = adj_mat.shape[0], adj_mat.shape[1]

        msg = th.mm(node_repr, layer_weight.T)
        if self.args.hyp_rotation:
            msg = th.mm(msg, layer_weight_h)

        msg = msg * mask
        real_node_num = (mask > 0).sum()

        # One row per (node, neighbour slot) pair, in row-major order of adj_mat.
        neighbors = th.index_select(msg, 0, adj_mat.view(-1))

        return self.lorentz_mean(neighbors, node_num, max_neighbor, real_node_num, weight, c=c)

    def get_combined_msg(self, step, node_repr, adj_mat, weight, mask, c):
        """
        Run message passing for every message type of one layer and sum the
        results.

        Args:
            step: layer index; mapped through `tie_list` when weights are tied.
            node_repr: current node representations.
            adj_mat: per-message-type list of adjacency tensors.
            weight: per-message-type list of neighbour weights.
            mask: forwarded to `aggregate_msg`.
            c: forwarded to `aggregate_msg`.
        Returns:
            the sum of the aggregated messages over all message types.
        """
        if self.args.tie_weight:
            gnn_layer = self.tie_list[step]
        else:
            gnn_layer = step

        total_msg = None
        for msg_type in range(self.type_of_msg):
            rotation = self.retrieve_params(self.msg_weight, gnn_layer)

            boost = None
            if self.args.hyp_rotation:
                boost_params = getattr(self, "msg_%d_weight_h" % msg_type)
                boost = self.retrieve_params_h(boost_params, gnn_layer)

            type_msg = self.aggregate_msg(
                node_repr, adj_mat[msg_type], weight[msg_type],
                rotation, boost, mask, c,
            )

            if total_msg is None:
                total_msg = type_msg
            else:
                total_msg = total_msg + type_msg
        return total_msg


    def encode(self, node_repr, adj_list, weight, c):
        """
        Encode node features with `args.num_layers` rounds of message passing.

        The features are projected and activated, mapped with the manifold's
        `exp_map_zero`, and then each layer aggregates messages, applies the
        manifold-aware activation and renormalises with the manifold's
        `normalize`.

        The all-ones node mask is created directly on `args.device`, so the
        method also works when CUDA is not available.

        Args:
            node_repr: raw node features fed to the projection layer.
            adj_list: adjacency input, wrapped by `split_input`.
            weight: neighbour weights, wrapped by `split_input`.
            c: forwarded to the manifold operations and message passing.
        Returns:
            for the 'nc' task, `(node_centroid_sim.squeeze(), node_repr)`;
            otherwise `(node_repr, node_repr)`.
        """
        node_repr = self.activation(self.linear(node_repr))
        adj_list, weight = self.split_input(adj_list, weight)
        mask = torch.ones((node_repr.size(0), 1), device=self.args.device)
        node_repr = self.args.manifold.exp_map_zero(node_repr, c)

        for step in range(self.args.num_layers):
            node_repr = node_repr * mask
            combined_msg = self.get_combined_msg(step, node_repr, adj_list, weight, mask, c)
            node_repr = combined_msg * mask
            node_repr = self.apply_activation(node_repr) * mask
            node_repr = self.args.manifold.normalize(node_repr, c)

        if self.args.task == 'nc':
            _, node_centroid_sim = self.distance(node_repr, mask, c)
            return node_centroid_sim.squeeze(), node_repr
        return node_repr, node_repr

class Encoder(nn.Module):
    """
    Encoder abstract class.

    `encode` reads `self.layers` and `self.encode_graph`, neither of which is
    set here; they are expected to be provided before `encode` is called.
    """

    def __init__(self, c):
        super().__init__()
        self.c = c

    def encode(self, x, adj):
        """
        Run `self.layers` on the input.

        When `self.encode_graph` is truthy the layers receive the pair
        `(x, adj)` and only the first element of their result is returned;
        otherwise they receive `x` alone.
        """
        if not self.encode_graph:
            return self.layers.forward(x)

        graph_input = (x, adj)
        output, _ = self.layers.forward(graph_input)
        return output
