r"""modified from torch_geometric"""
from networkx import from_scipy_sparse_array as from_scipy_sparse_matrix
import torch
from torch_geometric.utils import dense_to_sparse
from torch_geometric.nn import GCNConv, EdgeConv, GINConv
from torch.nn import Embedding, Sequential, Linear, ModuleList, ReLU, LeakyReLU
from torch.nn.parameter import Parameter
import numpy as np
import torch.nn.functional as F


def target_distribution(q):
    """Sharpen a 2-D soft-assignment matrix into a target distribution.

    Each entry is squared and divided by the sum of its column; every row of
    the result is then rescaled so that it sums to one.

    Args:
        q: 2-D tensor of soft assignments (one row per sample, one column per
            cluster).

    Returns:
        Tensor with the same shape as ``q`` whose rows each sum to one.
    """
    column_mass = q.sum(dim=0)
    sharpened = q.pow(2) / column_mass
    row_mass = sharpened.sum(dim=1)
    # Transpose so the per-row totals broadcast along the trailing axis,
    # then transpose back to the original orientation.
    return torch.div(sharpened.t(), row_mass).t()
    

class GraphConv(torch.nn.Module):
    """Parameter-free neighbourhood aggregation: ``adj @ x``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, adj):
        """Aggregate node features through the adjacency matrix.

        Args:
            x: 2-D feature matrix.
            adj: 2-D matrix that left-multiplies ``x``.

        Returns:
            The matrix product of ``adj`` and ``x``.
        """
        return torch.mm(adj, x)


class Rewrite_GCNEncoder(torch.nn.Module):
    """Encoder with a learnable edge mask and a soft cluster-assignment head.

    The input adjacency is re-weighted element-wise by ``sigmoid(mask_logits)``
    and an identity matrix is added (self-loops).  Features are then passed
    through ``num_layers + 1`` rounds of "aggregate with the masked adjacency,
    apply a linear layer, apply leaky ReLU".  The final embedding is compared
    against learnable cluster centres to produce soft assignments.

    Args:
        sample_num: Number of nodes; fixes the size of the square edge mask
            and of the identity matrix.
        in_channels: Width of the input features.
        hidden_channels: Width of the intermediate layers.
        out_channels: Width of the final embedding and of each cluster centre.
        num_layers: Number of hidden linear layers placed before the output
            projection.
        aggr: Accepted for interface compatibility; not used by this class.
        n_label: Number of cluster centres (default ``2``).
    """

    def __init__(self, sample_num, in_channels, hidden_channels, out_channels,
                 num_layers=4, aggr='add', n_label=2):
        super(Rewrite_GCNEncoder, self).__init__()
        self.num_layers = num_layers

        self.propagate_list = ModuleList()
        self.aggregate_gcn = GraphConv()
        for layer_idx in range(num_layers):
            # Only the first layer consumes the raw input width.
            width_in = hidden_channels if layer_idx else in_channels
            self.propagate_list.append(
                Sequential(Linear(width_in, hidden_channels)))
        self.propagate_list.append(
            torch.nn.Linear(hidden_channels, out_channels))
        self.cluster = torch.nn.Linear(out_channels, out_channels)

        # Cluster centres.  ``torch.empty`` returns uninitialised memory, so
        # the values are initialised explicitly to avoid NaN/inf garbage when
        # the caller does not overwrite them before the first forward pass.
        self.cluster_layer = Parameter(torch.empty(n_label, out_channels))
        torch.nn.init.xavier_normal_(self.cluster_layer)
        self.alpha = 1.0

        # Zero logits -> sigmoid gives 0.5 for every edge at initialisation.
        self.mask_logits = Parameter(torch.zeros(sample_num, sample_num))
        # Masked adjacency computed by the most recent ``forward`` call.
        self.masked_A = None

        # Non-persistent buffer: follows ``.to()`` / ``.cuda()`` together with
        # the module while staying out of ``state_dict`` so existing
        # checkpoints keep loading.
        self.register_buffer(
            "I", torch.eye(sample_num, sample_num), persistent=False)

    def forward(self, x, A, mode='train'):
        """Encode node features and compute soft cluster assignments.

        Args:
            x: 2-D node-feature matrix with ``in_channels`` columns.
            A: Adjacency matrix, combined element-wise with the
                ``(sample_num, sample_num)`` edge mask.
            mode: Accepted for interface compatibility; not used here.

        Returns:
            A tuple ``(first_emb, z, q)`` where ``first_emb`` is a one-element
            list holding the features after the first aggregation (before any
            linear layer), ``z`` is the output of the ``cluster`` projection
            and ``q`` is the row-normalised soft assignment of each row of
            ``z`` to the cluster centres.
        """
        mask = torch.sigmoid(self.mask_logits)
        # Defensive move in case the module and ``A`` live on different devices.
        identity = self.I.to(A.device)
        masked_A = mask * A + identity
        self.masked_A = masked_A

        first_emb = []
        for layer_idx, layer in enumerate(self.propagate_list):
            x = self.aggregate_gcn(x, masked_A)
            if layer_idx == 0:
                first_emb.append(x)
            x = F.leaky_relu(layer(x))

        z = self.cluster(x)
        # Squared Euclidean distance from every embedding to every centre:
        # (N, 1, D) - (K, D) broadcasts to (N, K, D), reduced over D.
        sq_dist = torch.sum(
            torch.pow(z.unsqueeze(1) - self.cluster_layer, 2), 2)
        # Student's-t style kernel: (1 + d^2 / alpha) ** (-(alpha + 1) / 2).
        q = 1.0 / (1.0 + sq_dist / self.alpha)
        q = q.pow((self.alpha + 1.0) / 2.0)
        # Normalise each row so the assignments of one sample sum to one.
        q = (q.t() / torch.sum(q, 1)).t()
        return first_emb, z, q
    

def from_adj_to_edge(A):
    """Convert a dense adjacency matrix to an edge list.

    Thin wrapper around ``torch_geometric.utils.dense_to_sparse``.

    Args:
        A: Dense adjacency matrix.

    Returns:
        The ``(edge_index, edge_weight)`` pair produced by ``dense_to_sparse``.
    """
    index_and_weight = dense_to_sparse(A)
    edge_index, edge_weight = index_and_weight
    return edge_index, edge_weight

def from_edge_to_adj(edge_index, edge_attr, num_nodes):
    """Convert an edge list back into a dense adjacency matrix.

    Thin wrapper around ``torch_geometric.utils.to_dense_adj``.

    Args:
        edge_index: Edge indices in the format accepted by ``to_dense_adj``.
        edge_attr: Per-edge values written into the dense matrix.
        num_nodes: Size of the dense matrix along each node dimension
            (forwarded as ``max_num_nodes``).

    Returns:
        The tensor returned by ``to_dense_adj``.
        NOTE(review): with no ``batch`` vector PyG documents a leading batch
        dimension of size one -- confirm callers expect ``(1, N, N)``.
    """
    # ``torch.geometric`` does not exist; the helper lives in the separate
    # ``torch_geometric`` package.  Imported locally because the module
    # header only brings in ``dense_to_sparse``.
    from torch_geometric.utils import to_dense_adj

    return to_dense_adj(
        edge_index, edge_attr=edge_attr, max_num_nodes=num_nodes)