import scipy.sparse as sp
import torch
import sys
import pickle
import networkx as nx
from torch_geometric.utils.convert import to_scipy_sparse_matrix
import numpy as np
import json
from scipy.sparse import coo_matrix
#ref: https://discuss.pytorch.org/t/creating-a-sparse-tensor-from-csr-matrix/13658/4


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a float32 torch sparse COO tensor.

    Args:
        sparse_mx: any scipy sparse matrix exposing ``tocoo()``.

    Returns:
        A 2-D ``torch.float32`` sparse COO tensor with the same shape and
        the same stored entries as ``sparse_mx``. The result lives on the
        CPU and is not coalesced.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    # Row 0 holds the row indices and row 1 the column indices, as int64,
    # which is the index layout torch expects for COO tensors.
    indices = torch.from_numpy(
        np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    # torch.sparse_coo_tensor replaces the deprecated legacy constructor
    # torch.sparse.FloatTensor(indices, values, shape).
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))

def GCN_diffusion(sptensor, order, feature, device=None):
    """Apply ``order`` rounds of GCN-style normalized propagation.

    Builds ``A_gcn = W + I`` from the sparse matrix ``sptensor`` (``W``) and
    repeatedly applies ``x -> D * (A_gcn @ (D * x))``, where ``D`` is a
    per-node scaling column derived from the column sums of ``A_gcn``.
    Reference: https://arxiv.org/pdf/1609.02907.pdf

    Args:
        sptensor: square torch sparse tensor ``W``.
        order: number of propagation rounds to run.
        feature: dense tensor whose first dimension matches ``sptensor``;
            it is the right-hand operand of ``torch.spmm``.
        device: device on which the identity matrix is created. ``None``
            (the default) uses ``sptensor.device`` so the addition below
            never mixes devices; an explicit value behaves as before.

    Returns:
        A list of ``order`` dense tensors; entry ``i`` is the feature after
        ``i + 1`` propagation rounds. The list is empty when ``order`` is 0.
    """
    if device is None:
        device = sptensor.device

    # Build the sparse identity directly on the target device.
    n_nodes = sptensor.size(0)
    diag = torch.arange(n_nodes, device=device)
    identity = torch.sparse_coo_tensor(
        torch.stack((diag, diag)),
        torch.ones(n_nodes, dtype=torch.float32, device=device),
        (n_nodes, n_nodes),
    )
    A_gcn = sptensor + identity

    # Column sums of A_gcn; torch.sparse.sum returns a sparse vector here.
    degrees = torch.sparse.sum(A_gcn, 0).to_dense()
    # NOTE(review): A_gcn already contains the self loop, so the extra "+ 1"
    # makes the normalizer (colsum(W) + 2) ** -0.5 instead of the
    # (degree + 1) ** -0.5 of the referenced paper. Kept as-is to preserve
    # numerical behavior -- confirm whether this is intended.
    D = torch.pow(degrees + 1, -0.5).unsqueeze(dim=1)

    gcn_diffusion_list = []
    propagated = feature
    for _ in range(order):
        # Symmetric scaling: D * (A_gcn @ (D * x)); D broadcasts over columns.
        propagated = torch.mul(propagated, D)
        propagated = torch.spmm(A_gcn, propagated)
        propagated = torch.mul(propagated, D)
        gcn_diffusion_list.append(propagated)
    return gcn_diffusion_list

def SCT1st(sptensor,order,feature):
    """Return the difference of two lazy-diffusion scales of ``feature``.

    With ``P x = 0.5 * x + 0.5 * W @ (D * x)``, where ``W`` is ``sptensor``
    and ``D = 1 / (column sums of W + 1)``, the result is
    ``P^(2^(order-1)) feature - P^(2^order) feature``.

    Args:
        sptensor: torch sparse tensor ``W``.
        order: scale index; ``2 ** (order - 1)`` diffusion steps are run
            before the first snapshot and the same number again after it.
        feature: dense tensor used as the right-hand operand of
            ``torch.spmm`` with ``sptensor``.

    Returns:
        Dense tensor with the same shape as ``feature``.
    """
    column_sums = torch.sparse.sum(sptensor, 0).to_dense()
    inv_degree = torch.pow(column_sums + 1, -1).unsqueeze(dim=1)
    n_steps = 2 ** (order - 1)

    def lazy_step(x):
        # One application of P: average x with its degree-scaled diffusion.
        return 0.5 * x + 0.5 * torch.spmm(sptensor, inv_degree * x)

    # First snapshot: P^(2^(order-1)) applied to the input.
    diffused = feature
    for _ in range(n_steps):
        diffused = lazy_step(diffused)
    coarse = diffused

    # Same number of steps again gives P^(2^order) applied to the input.
    for _ in range(n_steps):
        diffused = lazy_step(diffused)

    return coarse - diffused

def scattering_diffusion(sptensor,feature):
    """Return the first scattering feature of ``feature`` on ``sptensor``.

    Thin wrapper that calls ``SCT1stv2`` with ``order`` fixed to 1.

    Args:
        sptensor: sparse matrix passed through unchanged; the original
            notes describe it as shape (N, N).
        feature: tensor passed through unchanged; the original notes
            describe it as a torch.FloatTensor of shape (N, 3).

    The original notes also state that all inputs are on cuda; nothing in
    this function enforces that.
    """
    return SCT1stv2(sptensor, 1, feature)

def SCT1stv2(sptensor,order,feature):
    """Return ``P feature - P^2 feature`` for the lazy diffusion ``P``.

    ``P x = 0.5 * x + 0.5 * W @ (D * x)``, where ``W`` is ``sptensor`` and
    ``D = 1 / (column sums of W + 1)``. Snapshots are taken after the steps
    whose zero-based index is ``2**k - 1`` for ``k`` in ``0..order``, and the
    result is the first snapshot minus the second.

    Args:
        sptensor: torch sparse tensor ``W``.
        order: number of dyadic scales; must be at least 1 so that two
            snapshots exist (``order=0`` raises ``IndexError``).
        feature: dense tensor used as the right-hand operand of
            ``torch.spmm`` with ``sptensor``.

    Returns:
        Dense tensor with the same shape as ``feature``.
    """
    column_sums = torch.sparse.sum(sptensor, 0).to_dense()
    inv_degree = torch.pow(column_sums + 1, -1).unsqueeze(dim=1)
    n_steps = 2 ** order
    snapshot_steps = {2 ** k - 1 for k in range(order + 1)}

    diffused = feature
    snapshots = []
    for step in range(n_steps):
        diffused = 0.5 * diffused + 0.5 * torch.spmm(sptensor, inv_degree * diffused)
        if step in snapshot_steps:
            snapshots.append(diffused)
            # Only the first two snapshots feed the result, so the remaining
            # 2**order - 2 diffusion steps would be wasted work.
            if len(snapshots) == 2:
                break
    return snapshots[0] - snapshots[1]

