import torch.nn as nn
import torch
import math
import numpy as np
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import torch_sparse

class SineEncoding(nn.Module):
    """Sinusoidal feature expansion followed by a linear projection.

    The input is scaled by ``self.constant``; the scaled tensor, its sine and
    its cosine (both taken after a hop-dependent rescaling) are concatenated
    along ``dim=1`` and projected from ``3 * hidden_dim`` back to
    ``hidden_dim`` by ``adj_w``.

    Args:
        hidden_dim: Width of the input along ``dim=1`` and of the output.
        hops_num: Multiplier inside the exponential that sets the frequency
            scaling applied before the sine/cosine.
    """

    def __init__(self, hidden_dim=128, hops_num=1):
        super().__init__()
        self.hops = hops_num
        self.hidden_dim = hidden_dim
        self.constant = 100
        self.adj_w = nn.Linear(3 * hidden_dim, hidden_dim)

    def forward(self, k_hop_neighbor):
        """Encode ``k_hop_neighbor`` and return a tensor with ``hidden_dim`` columns."""
        scaled = self.constant * k_hop_neighbor

        # Every element of ``freq`` holds the same value,
        # exp(hops * -ln(10000) / hidden_dim); it is built with ``ones_like``
        # so it shares the input's shape, dtype and device.
        decay = -math.log(10000) / self.hidden_dim
        freq = torch.exp(self.hops * torch.ones_like(scaled) * decay)

        phase = scaled * freq
        stacked = torch.cat([scaled, phase.sin(), phase.cos()], dim=1)
        return self.adj_w(stacked)


class FeedForwardNetwork(nn.Module):
    """Two-layer perceptron: ``Linear -> GELU -> Linear``.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the intermediate representation.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Sub-modules are registered in the order they are applied.
        self.layer1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.gelu = nn.GELU()
        self.layer2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Apply the two linear layers with a GELU in between."""
        hidden = self.gelu(self.layer1(x))
        return self.layer2(hidden)


class MPformer(nn.Module):
    """Graph model: two rounds of 1-/2-hop propagation, then one attention block.

    ``forward`` projects the node features to ``nhid``, propagates them twice
    through two normalised sparse operators (``a1`` and ``a2``, derived from
    the adjacency on the first call), passes the three resulting feature
    groups through ``SineEncoding`` modules, concatenates them
    (``7 * nhid`` columns), and feeds the result through a pre-norm
    multi-head-attention + feed-forward block and a linear decoder.

    Args:
        hops: Stored on ``self.hops``; not read elsewhere in this class.
        nclass: Number of output columns of the decoder.
        nfeat: Number of input feature columns.
        nlayer: Stored on ``self.nlayer``; not read elsewhere in this class.
        nhid: Hidden width.
        nheads: Number of attention heads.
        nodes: Number of rows of the ``cls`` tensor.
        tran_dropout: Dropout probability used inside the attention module
            and on both residual branches.
        feat_dropout: Dropout probability applied to the raw input features.
    """

    def __init__(self,
                 hops,
                 nclass,
                 nfeat,
                 nlayer=1,
                 nhid=128,
                 nheads=1,
                 nodes=100,
                 tran_dropout=0.0,
                 feat_dropout=0.0):
        super(MPformer, self).__init__()
        self.nfeat = nfeat
        self.nlayer = nlayer
        self.nheads = nheads
        self.nhid = nhid
        self.hops = hops
        self.k = 2

        self.act = F.relu
        # Set by _prepare_prop once a1/a2 have been computed successfully.
        self.initialized = False
        self.linear = nn.Linear(nfeat, nhid)
        self.decoder_1 = nn.Linear(nhid, nclass)

        # Input width is nhid + 2*nhid + 4*nhid: the three feature groups
        # concatenated in forward().
        self.encoding = nn.Linear(nhid*7, nhid)

        self.mha_norm = nn.LayerNorm(nhid)
        self.ffn_norm = nn.LayerNorm(nhid)
        self.mha_dropout = nn.Dropout(tran_dropout)
        self.ffn_dropout = nn.Dropout(tran_dropout)
        self.mha = nn.MultiheadAttention(nhid, nheads, tran_dropout)
        self.ffn = FeedForwardNetwork(nhid, nhid, nhid)

        self.feat_dp1 = nn.Dropout(feat_dropout)
        # Fixed dropout probability for the concatenated propagation features.
        self.dropout = 0.5

        # Normalised 1-hop / 2-hop operators, filled in lazily by _prepare_prop.
        self.a1 = None
        self.a2 = None

        # Previously ``nn.Parameter(torch.randn(nodes, 1)).cuda()``: that
        # raised on hosts without CUDA and, because ``.cuda()`` returns a plain
        # tensor, never registered a parameter. ``cls`` is not read by
        # forward(), so it is kept as a non-persistent buffer: it follows
        # ``.to(device)`` and leaves ``parameters()`` / ``state_dict()``
        # exactly as they were.
        self.register_buffer('cls', torch.randn(nodes, 1), persistent=False)
        self.hops_encoder_1 = SineEncoding(nhid, 1)
        self.hops_encoder_3 = SineEncoding(nhid*2, 3)
        self.hops_encoder_4 = SineEncoding(nhid*4, 4)

    @staticmethod
    def _indicator(sp_tensor: torch.Tensor) -> torch.Tensor:
        """Return a sparse tensor with 1.0 where the stored value is > 0, else 0.0.

        The sparsity pattern of the coalesced input is kept, so non-positive
        entries remain as explicitly stored zeros.
        """
        csp = sp_tensor.coalesce()
        return torch.sparse_coo_tensor(
            indices=csp.indices(),
            values=(csp.values() > 0).to(torch.float),
            size=csp.size(),
            dtype=torch.float
        )

    @staticmethod
    def _spspmm(sp1: torch.Tensor, sp2: torch.Tensor) -> torch.Tensor:
        """Multiply two sparse COO matrices with ``torch_sparse.spspmm``."""
        assert sp1.shape[1] == sp2.shape[0], 'Cannot multiply size %s with %s' % (sp1.shape, sp2.shape)
        sp1, sp2 = sp1.coalesce(), sp2.coalesce()
        rows, inner, cols = sp1.shape[0], sp1.shape[1], sp2.shape[1]
        indices, values = torch_sparse.spspmm(
            sp1.indices(), sp1.values(),
            sp2.indices(), sp2.values(),
            rows, inner, cols,
        )
        return torch.sparse_coo_tensor(
            indices=indices,
            values=values,
            size=(rows, cols),
            dtype=torch.float
        )

    @staticmethod
    def _inv_sqrt_degree(adj: torch.Tensor) -> torch.Tensor:
        """Return the dense vector ``rowsum(adj) ** -0.5`` with infinities zeroed.

        The row sums are densified so that rows without any stored entry
        still get a slot (value 0) instead of being dropped.
        """
        degree = torch.sparse.sum(adj, dim=1).to_dense()
        inv_sqrt = torch.pow(degree, -0.5)
        # A zero row sum gives inf; such rows are scaled by 0 instead.
        return torch.where(torch.isinf(inv_sqrt), torch.zeros_like(inv_sqrt), inv_sqrt)

    @classmethod
    def _adj_norm(cls, adj: torch.Tensor) -> torch.Tensor:
        """Return ``D^-1/2 @ adj @ D^-1/2`` where ``D`` holds the row sums of ``adj``."""
        n = adj.size(0)
        d_diag = cls._inv_sqrt_degree(adj)
        diag_idx = torch.arange(n, device=d_diag.device)
        d_tiled = torch.sparse_coo_tensor(
            indices=torch.stack([diag_idx, diag_idx]),
            values=d_diag,
            size=(n, n)
        )
        return cls._spspmm(cls._spspmm(d_tiled, adj), d_tiled)

    def _prepare_prop(self, adj):
        """Build and cache the normalised propagation operators from ``adj``.

        ``a1`` is the indicator of the positive entries of ``adj - I`` and
        ``a2`` the indicator of the positive entries of
        ``adj @ adj - adj - I``; both are then normalised with ``_adj_norm``.
        """
        n = adj.size(0)
        device = adj.device
        diag_idx = torch.arange(n, device=device)
        sp_eye = torch.sparse_coo_tensor(
            indices=torch.stack([diag_idx, diag_idx]),
            values=torch.ones(n, dtype=torch.float, device=device),
            size=(n, n),
            dtype=torch.float
        )
        # initialize A1, A2
        a1 = self._indicator(adj - sp_eye)
        a2 = self._indicator(self._spspmm(adj, adj) - adj - sp_eye)

        # norm A1 A2
        self.a1 = self._adj_norm(a1)
        self.a2 = self._adj_norm(a2)
        # Only mark as initialised once both operators exist, so a failure
        # above does not leave the model stuck with a1/a2 == None.
        self.initialized = True

    def forward(self, x, adj):
        """Compute per-node log-probabilities.

        Args:
            x: Node features with ``nfeat`` columns.
            adj: Sparse adjacency matrix. It is only used on the first call;
                the operators derived from it are cached and reused, so a
                different ``adj`` passed later is ignored.

        Returns:
            ``log_softmax`` over ``dim=1`` of the decoder output
            (``nclass`` columns).
        """
        if not self.initialized:
            self._prepare_prop(adj)

        h0 = self.linear(self.feat_dp1(x))

        # First round: 1-hop and 2-hop aggregation of the projected features.
        hop1 = torch.spmm(self.a1, h0)
        hop2 = torch.spmm(self.a2, h0)
        h1 = self.act(torch.cat([hop1, hop2], dim=1))

        # Second round: aggregate the first-round output again.
        hop1 = torch.spmm(self.a1, h1)
        hop2 = torch.spmm(self.a2, h1)
        h2 = self.act(self.hops_encoder_4(torch.cat([hop1, hop2], dim=1)))

        r_final = torch.cat(
            [self.hops_encoder_1(h0), self.hops_encoder_3(h1), h2],
            dim=1,
        )
        r_final = F.dropout(r_final, self.dropout, training=self.training)

        feat = self.encoding(r_final)

        # Pre-norm self-attention with a residual connection.
        # NOTE(review): the 2-D tensor is passed straight to
        # nn.MultiheadAttention; confirm how the pinned torch version
        # interprets a 2-D input.
        mha_h = self.mha_norm(feat)
        mha_h, _ = self.mha(mha_h, mha_h, mha_h)
        mha_h = self.mha_dropout(mha_h)
        mha_h = mha_h + feat

        # Pre-norm feed-forward with a residual connection.
        ffn_h = self.ffn_norm(mha_h)
        ffn_h = self.ffn(ffn_h)
        ffn_h = self.ffn_dropout(ffn_h) + mha_h
        new_feat = self.decoder_1(ffn_h)

        return torch.log_softmax(new_feat, dim=1)
        

# Running this file as a script does nothing.
if __name__ == '__main__':
    pass






