import torch
import torch.nn as nn
from torch.nn import Parameter, Linear
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax, degree

from torch_geometric.nn.conv.gcn_conv import gcn_norm
# import torch_sparse

class LPFGraphAttLayer(MessagePassing):
    """GAT-style message-passing layer whose attention also scores node embeddings.

    For every edge and every head the raw attention score is

        att_i . x_i + att_em_i . e_i + att_j . x_j + att_em_j . e_j

    where ``x`` are the features after ``self.lin``, ``e_i`` is
    ``embedding[edge_index_i]`` and ``e_j`` is ``embedding[edges[0]]``. The two
    embedding terms are omitted when ``embedding`` is None. Scores go through
    LeakyReLU and are softmax-normalised over edges sharing the same
    ``edge_index_i``. They are then dropped out and used to weight ``x_j``;
    weighted messages are summed (``aggr='add'``).

    Args:
        in_channels: size of the input node features.
        out_channels: per-head output size. ``embedding`` rows must have this
            width too, because they are concatenated next to ``x`` and scored
            by ``att_em_*`` of shape ``(1, heads, out_channels)``.
        heads: number of attention heads.
        concat: if True the heads are concatenated, otherwise averaged.
        negative_slope: LeakyReLU slope applied to the raw scores.
        dropout: dropout probability applied to the attention coefficients.
        bias: whether to add a learnable bias to the output.
        inter_dim: accepted for interface compatibility; unused.
        **kwargs: forwarded to ``MessagePassing``.
    """

    def __init__(self, in_channels, out_channels, heads=1, concat=True,
                 negative_slope=0.2, dropout=0, bias=True, inter_dim=-1, **kwargs):
        super(LPFGraphAttLayer, self).__init__(aggr='add', **kwargs)
        # message() returns (num_edges, heads, out_channels) tensors, so the
        # aggregation has to scatter along axis 0. This is set once here instead
        # of being reassigned inside message() on every call.
        self.node_dim = 0

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout

        # Attention coefficients from the last message() call; only filled in
        # when return_attention_weights is True, and cleared by forward().
        self.__alpha__ = None

        self.lin = Linear(in_channels, heads * out_channels, bias=False)

        self.att_i = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_j = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_em_i = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_em_j = Parameter(torch.Tensor(1, heads, out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-init the projection and feature attention; zero the rest.

        The embedding attention vectors start at zero, so initially only the
        feature terms contribute to the scores.
        """
        torch.nn.init.xavier_uniform_(self.lin.weight)
        torch.nn.init.xavier_uniform_(self.att_i)
        torch.nn.init.xavier_uniform_(self.att_j)
        torch.nn.init.zeros_(self.att_em_i)
        torch.nn.init.zeros_(self.att_em_j)
        # bias is None when the layer was built with bias=False.
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x, edge_index, embedding, return_attention_weights=False):
        """Run one attention-weighted propagation step.

        Args:
            x: node features, or a ``(source, target)`` pair of feature tensors;
                each is projected with ``self.lin``.
            edge_index: ``(2, num_edges)`` connectivity. Existing self loops are
                removed and one self loop per target node is added.
            embedding: per-node embedding indexed by node id, or None to use
                feature-only attention.
            return_attention_weights: if True, also return the self-looped
                ``edge_index`` and the attention coefficients.

        Returns:
            ``out`` with ``heads * out_channels`` columns when ``concat`` is True,
            otherwise ``out_channels`` columns (heads averaged). When
            ``return_attention_weights`` is True, returns
            ``(out, (edge_index, alpha))``.
        """
        if torch.is_tensor(x):
            x = self.lin(x)
            x = (x, x)
        else:
            x = (self.lin(x[0]), self.lin(x[1]))

        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index,
                                       num_nodes=x[1].size(self.node_dim))

        # edges=edge_index is forwarded so message() can look up source embeddings.
        out = self.propagate(edge_index, x=x, embedding=embedding, edges=edge_index,
                             return_attention_weights=return_attention_weights)

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out = out + self.bias

        if return_attention_weights:
            alpha, self.__alpha__ = self.__alpha__, None
            return out, (edge_index, alpha)
        return out

    def message(self, x_i, x_j, edge_index_i, size_i,
                embedding,
                edges,
                return_attention_weights):
        """Compute attention-weighted messages ``alpha * x_j`` per edge and head."""
        x_i = x_i.view(-1, self.heads, self.out_channels)
        x_j = x_j.view(-1, self.heads, self.out_channels)

        if embedding is not None:
            embedding_i, embedding_j = embedding[edge_index_i], embedding[edges[0]]
            embedding_i = embedding_i.unsqueeze(1).repeat(1, self.heads, 1)
            embedding_j = embedding_j.unsqueeze(1).repeat(1, self.heads, 1)

            key_i = torch.cat((x_i, embedding_i), dim=-1)
            key_j = torch.cat((x_j, embedding_j), dim=-1)
            cat_att_i = torch.cat((self.att_i, self.att_em_i), dim=-1)
            cat_att_j = torch.cat((self.att_j, self.att_em_j), dim=-1)
        else:
            # Without embeddings, score the projected features alone. Previously
            # this path raised UnboundLocalError on key_i / key_j.
            key_i, key_j = x_i, x_j
            cat_att_i, cat_att_j = self.att_i, self.att_j

        alpha = (key_i * cat_att_i).sum(-1) + (key_j * cat_att_j).sum(-1)
        alpha = alpha.view(-1, self.heads, 1)

        alpha = F.leaky_relu(alpha, self.negative_slope)
        # Normalise over all edges that share the same target node.
        alpha = softmax(alpha, edge_index_i, num_nodes=size_i)

        if return_attention_weights:
            self.__alpha__ = alpha

        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        return x_j * alpha.view(-1, self.heads, 1)
    

class LPFGraphLayer(nn.Module):
    """Linear projection followed by degree-normalised sparse propagation.

    Features are mapped with ``self.linear``. The graph is rebuilt with exactly
    one self loop per node and turned into a sparse matrix. Its entry at
    ``(row, col)`` is ``deg[row]^-1/2 * deg[col]^-1/2``, where ``deg`` counts
    occurrences of each node in ``edge_index[1]``. The projected features are
    left-multiplied by that matrix.

    NOTE(review): ``sparse.mm`` aggregates into the ``row`` index while degrees
    are counted over ``col``. For a non-symmetric ``edge_index`` this is the
    transpose of the usual GCN propagation; confirm this is intended.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        self.linear = nn.Linear(in_channel, out_channel)
        # Constructed (and therefore part of state_dict) but not applied in
        # forward; kept so existing checkpoints still load.
        self.bn = nn.BatchNorm1d(out_channel)
        self.relu = nn.ReLU()

    def forward(self, x, edge_index):
        """Return ``A_norm @ linear(x)`` for the self-looped graph ``edge_index``."""
        h = self.linear(x)
        n_nodes = h.shape[0]

        no_loops, _ = remove_self_loops(edge_index)
        graph, _ = add_self_loops(no_loops, num_nodes=n_nodes)
        src, dst = graph[0], graph[1]

        inv_sqrt_deg = degree(dst, h.size(0), dtype=h.dtype).pow(-0.5)
        # Isolated nodes would give 1/sqrt(0) = inf; zero them out instead.
        inv_sqrt_deg = inv_sqrt_deg.masked_fill(inv_sqrt_deg == float('inf'), 0)
        weights = inv_sqrt_deg[src] * inv_sqrt_deg[dst]

        side = int(graph.max()) + 1
        norm_adj = torch.sparse_coo_tensor(
            indices=graph,
            values=weights,
            size=(side, side),
            device=graph.device,
        )
        return torch.sparse.mm(norm_adj, h)
    
class LPFAttLayer(nn.Module):
    """Dense single-head graph attention with an additive node-embedding term.

    For each ordered pair ``(i, j)`` the raw score is
    ``a^T [x_i || x_j || e_i || e_j]``. Here ``x`` are the features after
    ``self.lin``, ``e`` the rows of ``embedding``, and ``a`` the concatenation
    of ``att`` and ``att_em``. Scores go through LeakyReLU and are kept only
    where an edge ``i -> j`` exists or ``i == j``. They are then normalised by
    a softmax over ``j`` and used to average the projected features.

    ``forward`` only works for ``heads == 1``, and ``embedding`` must have
    ``out_channels`` columns; other shapes raise a matmul shape error.
    ``concat``, ``bias``, ``inter_dim`` and ``return_attention_weights`` are
    accepted for interface compatibility but unused.
    """

    def __init__(self, in_channels, out_channels, heads=1, concat=True,
                 negative_slope=0.2, dropout=0, bias=True, inter_dim=-1, **kwargs):
        super(LPFAttLayer, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout

        self.__alpha__ = None

        self.lin = Linear(in_channels, heads * out_channels, bias=False)

        # att scores [x_i || x_j]; att_em scores [e_i || e_j].
        self.att = Parameter(torch.Tensor(2 * out_channels, 1))
        self.att_em = Parameter(torch.Tensor(2 * out_channels, 1))

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-init projection and feature attention; zero the embedding attention."""
        torch.nn.init.xavier_uniform_(self.lin.weight)
        torch.nn.init.xavier_uniform_(self.att)
        torch.nn.init.zeros_(self.att_em)

    def forward(self, x, edge_index, embedding, return_attention_weights=False):
        """Return attention-averaged projected features, one row per node.

        Args:
            x: node features with ``in_channels`` columns.
            edge_index: ``(2, num_edges)`` index tensor; ``edge_index[0]`` is
                the attending node, ``edge_index[1]`` the attended one.
            embedding: per-node embedding with ``out_channels`` columns.
            return_attention_weights: ignored.
        """
        num_nodes = x.size(0)

        # mask[i, j] is True when node i attends to node j: an edge i -> j exists,
        # or i == j. This is the dense equivalent of dropping existing self loops
        # and adding one per node, and it avoids a sparse->dense round trip.
        mask = torch.zeros(num_nodes, num_nodes, dtype=torch.bool, device=x.device)
        mask[edge_index[0], edge_index[1]] = True
        diag = torch.arange(num_nodes, device=x.device)
        mask[diag, diag] = True

        x = self.lin(x)

        att = torch.cat((self.att, self.att_em), dim=0)  # (4*out_channels, 1)
        a_x_i, a_x_j, a_em_i, a_em_j = torch.split(att, self.out_channels, dim=0)
        # a^T [x_i || x_j || e_i || e_j] = (x_i.a1 + e_i.a3) + (x_j.a2 + e_j.a4).
        # Scoring each node once and broadcasting gives O(N^2) memory instead of
        # materialising an (N, N, 4*out_channels) pair tensor.
        score_i = x @ a_x_i + embedding @ a_em_i  # (N, 1)
        score_j = x @ a_x_j + embedding @ a_em_j  # (N, 1)
        alpha = F.leaky_relu(score_i + score_j.t(), self.negative_slope)

        # A large negative fill (not -inf) keeps every row finite; each row has
        # at least its self loop, so the softmax is well defined.
        attention = alpha.masked_fill(~mask, -9e15)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        return torch.matmul(attention, x)

class SpecialSpmmFunction(torch.autograd.Function):
    """Sparse @ dense product whose backward only touches the sparse pattern.

    ``forward`` computes ``A @ b``, where ``A`` is the COO matrix built from
    ``indices`` / ``values`` / ``shape``. In ``backward``, the gradient with
    respect to ``values`` is evaluated only at the stored ``(row, col)``
    positions, as ``sum_k grad_out[row, k] * b[col, k]``. It never
    materialises the dense ``grad_out @ b.T`` matrix, and works for
    rectangular ``shape`` too.
    """

    @staticmethod
    def forward(ctx, indices, values, shape, b):
        assert not indices.requires_grad
        a = torch.sparse_coo_tensor(indices, values, shape)
        ctx.save_for_backward(a, b)
        return torch.matmul(a, b)

    @staticmethod
    def backward(ctx, grad_output):
        a, b = ctx.saved_tensors
        grad_values = grad_b = None
        if ctx.needs_input_grad[1]:
            # dL/dA[r, c] = <grad_output[r], b[c]>; only the stored entries are
            # needed, in the same order as the ``values`` input.
            rows, cols = a._indices()
            grad_values = (grad_output[rows] * b[cols]).sum(dim=-1)
        if ctx.needs_input_grad[3]:
            grad_b = a.t().matmul(grad_output)
        return None, grad_values, None, grad_b


class SpecialSpmm(nn.Module):
    """``nn.Module`` front end for ``SpecialSpmmFunction``.

    Holds no parameters; ``forward`` just applies the autograd function.
    """

    def forward(self, indices, values, shape, b):
        """Return ``sparse_coo(indices, values, shape) @ b``."""
        spmm = SpecialSpmmFunction.apply
        return spmm(indices, values, shape, b)

class SpLPFAttLayer(nn.Module):
    """Sparse single-head graph attention with an additive node-embedding term.

    After forcing exactly one self loop per node, each edge ``(src, dst)`` gets
    the weight ``exp(LeakyReLU(a . [x_src || x_dst || e_src || e_dst]))``. Here
    ``x`` is the projected feature, ``e`` the row of ``embedding``, and ``a``
    the concatenation of ``att`` and ``att_em``. The output row for ``src`` is
    the weight-averaged ``x_dst`` over its edges. Dropout is applied to the
    numerator weights only; the normalising row sums use the undropped weights.

    Shapes only line up for ``heads == 1`` with ``embedding`` of width
    ``out_channels``. ``concat``, ``bias``, ``inter_dim`` and
    ``return_attention_weights`` are accepted for interface compatibility but
    unused.
    """

    def __init__(self, in_channels, out_channels, heads=1, concat=True,
                 negative_slope=0.2, dropout=0, bias=True, inter_dim=-1, **kwargs):
        super(SpLPFAttLayer, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout

        self.__alpha__ = None

        self.lin = Linear(in_channels, heads * out_channels, bias=False)

        # att scores [x_src || x_dst]; att_em scores [e_src || e_dst].
        self.att = Parameter(torch.Tensor(1, 2 * out_channels))
        self.att_em = Parameter(torch.Tensor(1, 2 * out_channels))

        self.special_spmm = SpecialSpmm()

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-init projection and feature attention; zero the embedding attention."""
        torch.nn.init.xavier_uniform_(self.lin.weight)
        torch.nn.init.xavier_uniform_(self.att)
        torch.nn.init.zeros_(self.att_em)

    def forward(self, x, edge_index, embedding, return_attention_weights=False):
        """Return attention-averaged projected features, one row per node."""
        num_nodes = x.shape[0]
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        # Every valid index is < num_nodes (x[src] below would fail otherwise),
        # so this equals edge_index.max() + 1 without forcing a host sync.
        shape = torch.Size([num_nodes, num_nodes])
        src, dst = edge_index[0], edge_index[1]

        x = self.lin(x)
        edge_x = torch.cat((x[src], x[dst]), dim=1).t()
        edge_x_emb = torch.cat((embedding[src], embedding[dst]), dim=1).t()
        edge_x = torch.cat((edge_x, edge_x_emb), dim=0)
        att = torch.cat((self.att, self.att_em), dim=-1)

        # squeeze(0), not squeeze(): with a single edge (a one-node graph),
        # squeeze() would produce a 0-d tensor that cannot serve as COO values.
        edge_e = torch.exp(F.leaky_relu(att.mm(edge_x).squeeze(0), self.negative_slope))
        # Sanity checks: an overflowing exp gives inf, and inf / inf turns into NaN below.
        assert not torch.isnan(edge_e).any()
        # Match edge_e's dtype/device so the sparse matmul also works for non-float32 input.
        ones = edge_e.new_ones((num_nodes, 1))
        e_rowsum = self.special_spmm(edge_index, edge_e, shape, ones)
        edge_e = F.dropout(edge_e, self.dropout, training=self.training)
        x_prime = self.special_spmm(edge_index, edge_e, shape, x)
        assert not torch.isnan(x_prime).any()
        out = x_prime.div(e_rowsum)

        assert not torch.isnan(out).any()

        return out