import math,os
import torch
import numpy as np
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch_sparse import SparseTensor, matmul
from torch_geometric.utils import degree, remove_self_loops, add_self_loops

from script.models.NodeFormer.layers import NodeFormerConv, adj_mul

class NodeFormer(nn.Module):
    """NodeFormer model: input projection, a stack of NodeFormerConv layers, output head.

    ``forward`` returns ``(x_out, link_loss_)``: the output of the final linear
    layer and a list holding one edge loss per conv layer (the list stays empty
    when ``use_edge_loss`` is False).

    Hyper-parameters read from ``args``: ``nfeat``, ``nhid``, ``nout``,
    ``trans_num_layers``, ``trans_num_heads``, ``rb_order``, ``dropout`` and
    ``num_nodes``.
    """

    def __init__(self, args, nb_random_features=30, use_bn=True, use_gumbel=True,
                 use_residual=True, use_act=False, use_jk=False, nb_gumbel_sample=10, rb_trans='sigmoid', use_edge_loss=True):
        """Build the layers.

        Args:
            args: namespace-like object providing the attributes listed in the
                class docstring.
            nb_random_features, use_gumbel, nb_gumbel_sample, rb_trans:
                forwarded unchanged to every ``NodeFormerConv``.
            use_bn: apply the normalisation layers in ``self.bns``.
            use_residual: add the previous layer's output to each conv output.
            use_act: apply the activation (ELU) after every conv layer.
            use_jk: concatenate the outputs of all layers before the final
                linear layer.
            use_edge_loss: when True each conv is expected to return
                ``(z, link_loss)`` instead of just ``z``.
        """
        super().__init__()

        self.in_channels = args.nfeat
        self.hidden_channels = args.nhid
        self.out_channels = args.nout
        self.num_layers = args.trans_num_layers
        self.num_heads = args.trans_num_heads
        self.rb_order = args.rb_order

        self.convs = nn.ModuleList()
        self.fcs = nn.ModuleList()
        self.fcs.append(nn.Linear(self.in_channels, self.hidden_channels))
        # NOTE: despite the "bn" name these are LayerNorm modules.
        self.bns = nn.ModuleList()
        self.bns.append(nn.LayerNorm(self.hidden_channels))
        for _ in range(self.num_layers):
            self.convs.append(
                NodeFormerConv(self.hidden_channels, self.hidden_channels, num_heads=self.num_heads,
                               nb_random_features=nb_random_features, use_gumbel=use_gumbel,
                               nb_gumbel_sample=nb_gumbel_sample, rb_order=self.rb_order,
                               rb_trans=rb_trans, use_edge_loss=use_edge_loss))
            self.bns.append(nn.LayerNorm(self.hidden_channels))

        # With jumping knowledge the head consumes the input projection plus
        # the output of every conv layer, concatenated on the last axis.
        if use_jk:
            head_in = self.hidden_channels * self.num_layers + self.hidden_channels
        else:
            head_in = self.hidden_channels
        self.fcs.append(nn.Linear(head_in, self.out_channels))

        self.dropout = args.dropout
        self.activation = F.elu
        self.use_bn = use_bn
        self.use_residual = use_residual
        self.use_act = use_act
        self.use_jk = use_jk
        self.use_edge_loss = use_edge_loss
        # Learnable node features, used by forward() when no ``x`` is passed.
        self.feat = Parameter(torch.ones(args.num_nodes, self.in_channels), requires_grad=True)

    def reset_parameters(self):
        """Re-initialise every learnable component, including ``self.feat``."""
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        for fc in self.fcs:
            fc.reset_parameters()
        # ``feat`` is trainable too; restore its constructor value (all ones)
        # so that a reset does not carry trained features into the next run.
        nn.init.ones_(self.feat)

    def forward(self, edge_index, x=None, tau=1.0):
        """Run the model on one graph.

        Args:
            edge_index: edge index of the graph; self-loops are removed and
                then re-added for all ``n`` nodes before use.
            x: node features whose first dimension is the number of nodes;
                defaults to the learnable ``self.feat``.
            tau: passed through to every conv layer.

        Returns:
            ``(x_out, link_loss_)``: the final linear layer's output with the
            leading size-1 axis squeezed away, and the list of per-layer edge
            losses (empty when ``use_edge_loss`` is False).
        """
        if x is None:
            x = self.feat
        n = x.shape[0]

        adj, _ = remove_self_loops(edge_index)
        adj, _ = add_self_loops(adj, num_nodes=n)
        adjs = [adj]
        # Higher-order adjacency: each step combines the current adjacency with
        # itself through ``adj_mul`` and keeps every intermediate result.
        for _ in range(self.rb_order - 1):
            adj = adj_mul(adj, adj, n)
            adjs.append(adj)

        # Prepend a size-1 leading axis (a batch of one graph).
        x = x.unsqueeze(0)
        layer_ = []
        link_loss_ = []
        z = self.fcs[0](x)
        if self.use_bn:
            z = self.bns[0](z)
        z = self.activation(z)
        z = F.dropout(z, p=self.dropout, training=self.training)
        layer_.append(z)

        for i, conv in enumerate(self.convs):
            if self.use_edge_loss:
                z, link_loss = conv(z, adjs, tau)
                link_loss_.append(link_loss)
            else:
                z = conv(z, adjs, tau)
            if self.use_residual:
                # Out-of-place add: mutating the conv output in place can
                # invalidate tensors that autograd saved for the backward pass.
                z = z + layer_[i]
            if self.use_bn:
                z = self.bns[i + 1](z)
            if self.use_act:
                z = self.activation(z)
            z = F.dropout(z, p=self.dropout, training=self.training)
            layer_.append(z)

        if self.use_jk:  # jumping knowledge: concatenate every layer's output
            z = torch.cat(layer_, dim=-1)

        x_out = self.fcs[-1](z).squeeze(0)

        return x_out, link_loss_

    def decoding_lp(self, z, edge_index):
        """Score edges by the product of their endpoint embeddings.

        Rows of ``z`` are looked up for ``edge_index[0]`` and ``edge_index[1]``,
        multiplied element-wise and summed over dimension 1.
        """
        edge_i = edge_index[0]
        edge_j = edge_index[1]
        z_i = F.embedding(edge_i, z)
        z_j = F.embedding(edge_j, z)
        return (z_i * z_j).sum(dim=1)
