import torch
from torch.nn import Parameter, Linear
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.utils import add_remaining_self_loops
import networkx as nx

from torch_geometric.nn.inits import glorot, zeros
import torch_scatter
from torch import Tensor
import numpy as np
from scipy import sparse
from scipy.sparse import coo_matrix
def compute_ppr(a, self_loop=True):
    """Return ``A @ A - A - I`` as a dense tensor for the adjacency ``a``.

    Despite its name, no PageRank iteration is performed: the result is the
    two-step walk count minus the direct adjacency minus the identity.

    Args:
        a: Square adjacency matrix, sparse COO or dense.
        self_loop: Unused; kept for backward compatibility.

    Returns:
        Dense tensor of the same size as ``a`` on ``a``'s device.
    """
    if a.is_sparse:
        a = a.to_dense()
    a_new = torch.mm(a, a)
    # Build the identity on a's device instead of a hard-coded cuda:1 so the
    # function works on CPU and on any GPU.
    return a_new - a - torch.eye(a.shape[0], device=a.device)


def get_g2(edge_index, feature):
    """Build the edge index of the 2-hop graph used by ``DMPConv``.

    A dense adjacency ``A`` is built from ``edge_index`` (duplicate edges are
    summed). The returned pairs ``(i, j)`` are exactly those with ``i != j``
    and ``(A @ A - A)[i, j] > 0``.

    Args:
        edge_index: Integer tensor of shape ``(2, E)``.
        feature: Only ``feature.shape[0]`` (the number of nodes) is used.

    Returns:
        Long tensor of shape ``(2, E2)`` on ``edge_index.device``, with pairs in
        row-major order (the same order scipy's ``coo_matrix`` produced).
    """
    num_nodes = feature.shape[0]
    device = edge_index.device
    value = torch.ones(edge_index.shape[1], dtype=torch.float32, device=device)
    adj = torch.sparse_coo_tensor(indices=edge_index, values=value,
                                  size=[num_nodes, num_nodes])

    # Clamp to [0, 1]: drop negative entries and cap multi-path counts at 1.
    adj = compute_ppr(adj).clamp(min=0, max=1)
    # Remove self-pairs. This is equivalent to the former "subtract I, then
    # zero negatives", because the clamped diagonal is at most 1.
    adj.fill_diagonal_(0)
    # nonzero() is row-major, matching coo_matrix(dense).row/.col, and it
    # stays on-device (no CPU/scipy round trip).
    return adj.nonzero().t().contiguous()


def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             add_self_loops=True, dtype=None):
    """Symmetrically normalise edge weights in the GCN style.

    Each weight ``w`` of edge ``(r, c)`` becomes
    ``deg[r] ** -0.5 * w * deg[c] ** -0.5``, where ``deg`` is the weighted
    in-degree accumulated over ``edge_index[1]``. Nodes with zero degree get
    a factor of 0 instead of ``inf``.

    Args:
        edge_index: Integer tensor of shape ``(2, E)``.
        edge_weight: Optional per-edge weights; all ones when ``None``.
        num_nodes: Number of nodes, forwarded to the self-loop and scatter
            helpers.
        improved: If True, added self-loops get weight 2 instead of 1. It has
            no effect when ``add_self_loops`` is False.
        add_self_loops: Whether to add self-loops to nodes that lack one.
            (This parameter shadows the module-level ``add_self_loops`` import
            inside this function.)
        dtype: dtype of the default all-ones weights.

    Returns:
        Tuple ``(edge_index, normalised_weight, deg)``. ``edge_index`` includes
        any added self-loops, and ``deg`` is the un-normalised degree vector.
    """
    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)

    if add_self_loops:
        loop_weight = 2. if improved else 1.
        edge_index, looped_weight = add_remaining_self_loops(
            edge_index, edge_weight, loop_weight, num_nodes)
        assert looped_weight is not None
        edge_weight = looped_weight

    row_idx = edge_index[0]
    col_idx = edge_index[1]
    deg = torch_scatter.scatter_add(edge_weight, col_idx, dim=0,
                                    dim_size=num_nodes)
    inv_sqrt = deg.pow(-0.5)
    # Zero-degree nodes would otherwise propagate inf into the weights.
    inv_sqrt.masked_fill_(inv_sqrt == float('inf'), 0)
    normalised = inv_sqrt[row_idx] * edge_weight * inv_sqrt[col_idx]
    return edge_index, normalised, deg


class DMPConv(MessagePassing):
    r"""Gated message passing over an input graph and a derived 2-hop graph.

    ``forward`` projects node features with ``weight`` and then runs ``K``
    rounds. Each round propagates once over the GCN-normalised (self-looped)
    input graph and once over the graph built by ``get_g2``. The initial
    projection and every propagation output of both streams are concatenated
    and projected with ``weight2``.

    Every message ``x_j`` is scaled per head and per channel by a gate computed
    from ``x_i * x_j``: ``lambda_ * tanh(...)`` on the input graph and
    ``lambda_ * sigmoid(...)`` on the 2-hop graph.

    Args:
        in_channels: Size of each input node feature.
        out_channels: Size of each output feature per head.
        heads: Number of heads.
        concat: If True, ``update`` concatenates heads; otherwise it averages
            them. ``weight2`` assumes each propagation output has width
            ``heads * out_channels``, which holds for ``concat=True`` or
            ``heads == 1``.
        negative_slope: Stored only; not read by this layer.
        dropout: Dropout probability applied to the edge gates in training.
        bias: If True, learnable biases are added in ``update``.
        K: Number of propagation rounds.
        **kwargs: Must contain ``lambda_``, the scale applied to the gates.
    """

    def __init__(self, in_channels, out_channels, heads=1, concat=True,
                 negative_slope=0.2, dropout=0, bias=True, K=5, **kwargs):
        super(DMPConv, self).__init__(aggr='add')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.K = K
        self.weight = Parameter(
            torch.Tensor(in_channels, heads * out_channels))
        # Maps [initial projection, K x (input-graph out, 2-hop out)] back to
        # heads * out_channels features.
        self.weight2 = Parameter(
            torch.Tensor(heads * out_channels * (2 * K + 1),
                         heads * out_channels))
        # Gate parameters for the input graph ...
        self.feat_att = Parameter(torch.Tensor(out_channels, heads, out_channels))
        self.feat_att_bias = Parameter(torch.Tensor(heads, out_channels))
        # ... and for the 2-hop graph.
        self.feat_att2 = Parameter(torch.Tensor(out_channels, heads, out_channels))
        self.feat_att_bias2 = Parameter(torch.Tensor(heads, out_channels))
        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
            self.bias2 = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
            self.bias2 = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
            # reset_parameters() and update() access bias2 as well, so it must
            # exist (as None) when bias is disabled.
            self.register_parameter('bias2', None)

        self.reset_parameters()
        self._cached_edge_index = None
        self._cached_edge_index2 = None
        self._cached_adj_t = None
        # Required keyword argument: scale applied to both edge gates.
        self.lambda_ = kwargs['lambda_']
        # Selects the graph for the propagate() call in progress:
        # False -> input graph, True -> 2-hop graph. message() and update()
        # read this flag.
        self.g2 = False

    def reset_parameters(self):
        """Glorot-initialise weight matrices and zero the biases."""
        glorot(self.weight)
        glorot(self.feat_att)
        zeros(self.bias)
        zeros(self.feat_att_bias)
        glorot(self.weight2)
        glorot(self.feat_att2)
        zeros(self.bias2)
        zeros(self.feat_att_bias2)

    def forward(self, x, edge_index, size=None):
        """Run ``K`` rounds of dual-graph propagation and project the result.

        Args:
            x: Node feature matrix; ``x.size(0)`` is the number of nodes.
            edge_index: Edge indices of the input graph (a ``Tensor``).
            size: Optional size forwarded to ``propagate``.

        Returns:
            ``x_final @ weight2``, where ``x_final`` concatenates the projected
            input with the output of every propagation step.
        """
        if isinstance(edge_index, Tensor):
            cache = self._cached_edge_index
            if cache is None:
                # The input-graph normalisation is recomputed on every call;
                # nothing ever fills _cached_edge_index.
                edge_index, self.edge_weight, self.deg = gcn_norm(
                    edge_index, None, x.size(0),
                    False, add_self_loops=True, dtype=x.dtype)
                self.deg = self.deg.view(-1, 1, 1)
            else:
                edge_index, self.edge_weight = cache[0], cache[1]

            cache2 = self._cached_edge_index2
            if cache2 is None:
                # NOTE(review): get_g2 receives edge_index *after* gcn_norm has
                # added self-loops. Paths through a self-loop then make every
                # 1-hop neighbour part of the "2-hop" graph as well. Confirm
                # this is intended.
                # NOTE(review): this cache is never invalidated, so calling
                # forward with a different graph reuses a stale 2-hop graph.
                edge_index2 = get_g2(edge_index, x)
                edge_index2, self.edge_weight2, self.deg2 = gcn_norm(
                    edge_index2, None, x.size(0),
                    False, add_self_loops=True, dtype=x.dtype)
                self.deg2 = self.deg2.view(-1, 1, 1)
                self._cached_edge_index2 = (edge_index2, self.edge_weight2)
            else:
                edge_index2, self.edge_weight2 = cache2[0], cache2[1]

        if torch.is_tensor(x):
            x = torch.matmul(x, self.weight)
        else:
            x = (None if x[0] is None else torch.matmul(x[0], self.weight),
                 None if x[1] is None else torch.matmul(x[1], self.weight))

        # update() adds these initial projections back at every step.
        self.x = x.view(-1, self.heads, self.out_channels)
        self.x2 = x.view(-1, self.heads, self.out_channels)

        x2 = x
        x_final = x
        try:
            for _ in range(self.K):
                # Set the flag explicitly before each propagate so that a
                # previously interrupted call cannot swap the two streams.
                self.g2 = False
                x = self.propagate(edge_index, size=size, x=x)
                self.g2 = True
                x2 = self.propagate(edge_index2, size=size, x=x2)
                x_final = torch.cat([x_final, x, x2], dim=-1)
        finally:
            self.g2 = False
        return torch.matmul(x_final, self.weight2)

    def _edge_gate(self, a, att, att_bias, activation):
        """Return ``lambda_ * activation(logits)`` with shape ``(E, heads, C)``.

        ``logits[e, h, k] = sum_c a[e, h, c] * att[k, h, c] + att_bias[h, k]``.
        This is the vectorised form of stacking, over every ``k``,
        ``(a * att[k]).sum(-1) + att_bias[:, k]``.
        """
        logits = torch.einsum('ehc,khc->ehk', a, att) + att_bias
        return self.lambda_ * activation(logits)

    def message(self, edge_index_i, x_i, x_j):
        """Scale each weighted neighbour message by a per-channel gate.

        Side effects: stores the per-node mean gate in ``self.alpha`` (input
        graph) or ``self.alpha2`` (2-hop graph). For the input graph it also
        stores the raw per-edge gates in ``self.save_alpha``.
        """
        if self.g2:
            edge_weight, att, att_bias, activation = (
                self.edge_weight2, self.feat_att2, self.feat_att_bias2,
                torch.sigmoid)
        else:
            edge_weight, att, att_bias, activation = (
                self.edge_weight, self.feat_att, self.feat_att_bias,
                torch.tanh)

        x_j = edge_weight.view(-1, 1) * x_j
        x_j = x_j.view(-1, self.heads, self.out_channels)
        if x_i is not None:
            x_i = x_i.view(-1, self.heads, self.out_channels)

        alpha = self._edge_gate(x_i * x_j, att, att_bias, activation)

        node_alpha = torch_scatter.scatter_mean(alpha, edge_index_i, dim=0)
        if self.g2:
            self.alpha2 = node_alpha
        else:
            self.save_alpha = alpha
            self.alpha = node_alpha

        # Sample attention coefficients stochastically.
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.view(-1, self.heads, self.out_channels)

    def update(self, aggr_out):
        """Combine aggregated messages with the initial projection.

        Computes ``x0 - alpha * x0 + aggr_out``. Here ``x0`` is the projection
        stored by ``forward`` and ``alpha`` is the per-node mean gate stored by
        ``message`` for the current graph. Heads are then concatenated or
        averaged, and the bias for the current graph is added.
        """
        if self.g2:
            x0, alpha, bias = self.x2, self.alpha2, self.bias2
        else:
            x0, alpha, bias = self.x, self.alpha, self.bias

        aggr_out = x0 - alpha * x0 + aggr_out
        if self.concat is True:
            aggr_out = aggr_out.view(-1, self.heads * self.out_channels)
        else:
            aggr_out = aggr_out.mean(dim=1)

        if bias is not None:
            aggr_out = aggr_out + bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)




