import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import eigsh
import pickle as pkl
import os

# forked from https://github.com/DropEdge/DropEdge

def get_diag(dataset, kwargs):
    """Load a precomputed, pickled eigendecomposition for a citation dataset.

    The file looked up is ``<datapath>/<dataset>D.p`` for the datasets
    ``'cora'``, ``'citeseer'`` and ``'pubmed'``.

    Args:
        dataset: dataset name.
        kwargs: options mapping; ``kwargs['datapath']`` is the directory that
            contains the pickle files.

    Returns:
        The unpickled object, or ``None`` when the dataset has no known file
        name or the file does not exist. Callers in this module index the
        result as ``[0]`` (eigenvalues) and ``[1]`` (eigenvectors),
        presumably the pair returned by ``scipy.sparse.linalg.eigsh`` --
        confirm against whatever produced the files.

    Raises:
        KeyError: if ``kwargs`` has no ``'datapath'`` entry.
    """
    datapath = kwargs['datapath']
    filenames = {
        'pubmed': 'pubmedD.p',
        'cora': 'coraD.p',
        'citeseer': 'citeseerD.p',
    }
    if dataset not in filenames:
        # Previously this fell through to an unbound local; callers expect None.
        return None
    # NOTE(security): pickle can execute arbitrary code -- only load trusted files.
    try:
        with open(os.path.join(datapath, filenames[dataset]), 'rb') as handle:
            return pkl.load(handle)
    except FileNotFoundError:
        print('The Laplacian must be diagonalized...')
        return None

def normalized_laplacian(adj):
    """Return the symmetric normalized Laplacian ``I - D^-1/2 A D^-1/2`` in COO format.

    ``D`` is the diagonal matrix of row sums of ``adj``; rows whose scaling
    factor would be infinite (zero row sum) get a factor of 0 instead.
    """
    adj = sp.coo_matrix(adj)
    n_nodes = adj.shape[0]
    degrees = np.array(adj.sum(1))
    scale = np.power(degrees, -0.5).flatten()
    scale[np.isinf(scale)] = 0.
    scaling = sp.diags(scale)
    sym_norm = scaling.dot(adj).dot(scaling)
    identity = sp.eye(n_nodes)
    return (identity - sym_norm).tocoo()


def laplacian(adj):
    """Return the combinatorial graph Laplacian ``D - A`` in COO format.

    ``D`` is the diagonal matrix holding the row sums of ``adj``.
    """
    adj = sp.coo_matrix(adj)
    degree_matrix = sp.diags(np.array(adj.sum(1)).flatten())
    lap = degree_matrix - adj
    return lap.tocoo()


def gcn(adj):
    """Return the first-order GCN propagation matrix ``I + D^-1/2 A D^-1/2`` in COO format.

    ``D`` holds the row sums of ``adj``; a zero row sum yields a scaling
    factor of 0 rather than ``inf``.
    """
    adj = sp.coo_matrix(adj)
    n_nodes = adj.shape[0]
    inv_sqrt_deg = np.power(np.array(adj.sum(1)), -0.5).flatten()
    inv_sqrt_deg[np.isinf(inv_sqrt_deg)] = 0.
    scaling = sp.diags(inv_sqrt_deg)
    propagated = scaling.dot(adj).dot(scaling)
    return (sp.eye(n_nodes) + propagated).tocoo()

def order_2(adj):
    """Return half of the squared first-order GCN matrix, ``gcn(adj) @ gcn(adj) / 2``."""
    propagation = gcn(adj)
    squared = propagation.dot(propagation)
    return squared / 2

def gcn_new_self(kwargs):
    """Build a normalizer computing ``gamma * I + D^-1/2 A D^-1/2`` (COO).

    ``kwargs['gamma']`` is checked and read only when the returned function
    is called, so the factory can be created with an empty options dict.
    """
    def _aux(adj):
        assert 'gamma' in kwargs, "self_coef should be given in args."
        gamma = kwargs['gamma']
        adj = sp.coo_matrix(adj)
        inv_sqrt_deg = np.power(np.array(adj.sum(1)), -0.5).flatten()
        inv_sqrt_deg[np.isinf(inv_sqrt_deg)] = 0.
        scaling = sp.diags(inv_sqrt_deg)
        self_loops = gamma * sp.eye(adj.shape[0])
        return (self_loops + scaling.dot(adj).dot(scaling)).tocoo()
    return _aux

def gcn_new_self_norm(kwargs):
    """Build a normalizer computing ``(gamma * I + D^-1/2 A D^-1/2) * 2 / (gamma + 1)`` (COO).

    The rescaling by ``2 / (gamma + 1)`` makes ``gamma == 1`` reproduce the
    plain ``I + D^-1/2 A D^-1/2`` matrix. ``kwargs['gamma']`` is checked and
    read only when the returned function is called.
    """
    def _aux(adj):
        assert 'gamma' in kwargs, "self_coef should be given in args."
        gamma = kwargs['gamma']
        adj = sp.coo_matrix(adj)
        inv_sqrt_deg = np.power(np.array(adj.sum(1)), -0.5).flatten()
        inv_sqrt_deg[np.isinf(inv_sqrt_deg)] = 0.
        scaling = sp.diags(inv_sqrt_deg)
        combined = gamma * sp.eye(adj.shape[0]) + scaling.dot(adj).dot(scaling)
        rescaled = combined * 2 / (gamma + 1)
        return rescaled.tocoo()
    return _aux

def gcn_new_self_with_L(kwargs):
    """Build a normalizer keeping the top ``nfreq`` eigencomponents of ``gcn_new_self``.

    The returned function eigendecomposes the dense matrix produced by
    ``gcn_new_self(kwargs)(adj)`` and rebuilds ``V_n diag(w_n) V_n^T`` from the
    ``kwargs['nfreq']`` trailing eigenpairs. The trailing pairs are the largest
    eigenvalues only if they come back in ascending order. ``nfreq == 0`` yields
    an all-zero matrix; ``nfreq`` larger than the matrix keeps every eigenpair.
    Returns a dense ``numpy`` array.
    """
    def _aux(adj):
        adj_gamma = gcn_new_self(kwargs)(adj)
        # k equals the matrix size, so all eigenpairs are requested.
        eigvals, eigvecs = eigsh(adj_gamma.toarray(), adj.shape[0])
        n = kwargs['nfreq']
        # An explicit start index is needed: ``[-n:]`` would select *every*
        # column when n == 0.
        start = max(len(eigvals) - n, 0)
        top_vecs = eigvecs[:, start:]
        # Column scaling equals V @ diag(w) without building a dense N x N diagonal.
        return (top_vecs * eigvals[start:]).dot(top_vecs.transpose())
    return _aux

def proj_positive(adj):
    """Clamp every entry of ``adj`` to be non-negative (elementwise ``max(x, 0)``).

    ``np.maximum`` preserves the input's floating dtype. The former
    ``np.vectorize(lambda x: max(x, 0))`` inferred its output dtype from the
    first element. When that element was negative, ``max`` returned the int
    ``0``, the whole result was cast to integers, and fractional entries were
    truncated to 0.
    """
    return np.maximum(adj, 0)

def MLP_with_HF(adj):
    """Return a sparse identity of matching size (no neighbour mixing).

    Only ``adj.shape[0]`` is used; the entries of ``adj`` are ignored.
    """
    num_nodes = adj.shape[0]
    return sp.eye(num_nodes)

def gcn_withoutHfreq(kwargs):
    """Build a normalizer that drops the ``nfreq`` trailing eigencomponents.

    The returned function loads a precomputed ``(eigenvalues, eigenvectors)``
    pair via ``get_diag``. It rebuilds ``V diag(w) V^T`` from every eigenpair
    except the ``kwargs['nfreq']`` trailing ones, which are the largest
    eigenvalues if the stored values are sorted ascending (TODO confirm for
    the pickled files). ``nfreq == 0`` keeps every eigenpair.

    Options read from ``kwargs`` at call time: ``nfreq``, ``projection``,
    ``dataset`` and, through ``get_diag``, ``datapath``. The ``adj``
    argument itself is not used. When ``projection`` is truthy, negative
    entries are clamped to 0 with ``proj_positive``.

    Raises:
        ValueError: if no precomputed decomposition is available.
    """
    def gcn_withoutHfreq_aux(adj):
        n, projection, dataset = kwargs['nfreq'], kwargs['projection'], kwargs['dataset']
        adj_diag = get_diag(dataset, kwargs)
        if adj_diag is None:
            # An explicit raise survives ``python -O``, unlike the former assert.
            raise ValueError('no precomputed eigendecomposition for dataset %r' % (dataset,))
        eigvals, eigvecs = adj_diag[0], adj_diag[1]
        # ``[:-n]`` would select *nothing* when n == 0; use an explicit stop.
        stop = max(len(eigvals) - n, 0)
        kept = eigvecs[:, :stop]
        # Column scaling equals V @ diag(w) without building a dense N x N diagonal.
        nedges = (kept * eigvals[:stop]).dot(kept.transpose())
        if projection:
            return proj_positive(nedges)
        return nedges
    return gcn_withoutHfreq_aux

def test_norm(adj, lo=70, hi=120):
    """Return ``gcn(adj)`` minus its spectral components with indices ``lo:hi``.

    The dense GCN matrix is fully eigendecomposed with ``eigsh``. The
    eigenpairs at positions ``lo:hi`` (in the order ``eigsh`` returns them)
    are rebuilt as ``V diag(w) V^T`` and subtracted. The defaults reproduce
    the previously hard-coded band 70:120; the slice silently shrinks for
    graphs with fewer nodes. The result is dense, because a sparse matrix
    minus a dense array is dense.
    """
    gcn_adj = gcn(adj)
    eigvals, eigvecs = eigsh(gcn_adj.toarray(), adj.shape[0])
    band = eigvecs[:, lo:hi]
    # Column scaling equals V @ diag(w) without building a dense diagonal.
    nedges = (band * eigvals[lo:hi]).dot(band.transpose())
    return gcn_adj - nedges

def gcn_withoutLfreq(n=0, projection=False, pubmed=False, gamma=0):
    """Build a normalizer returning ``gcn(adj)`` minus its ``n`` leading eigencomponents.

    Two calling conventions are supported:

    * ``gcn_withoutLfreq(kwargs)`` -- how ``fetch_normalization`` calls it.
      The options dict arrives in ``n`` and is read only when the returned
      function runs. It uses ``nfreq``, ``projection``, ``dataset`` and,
      through ``get_diag``, ``datapath``.
    * ``gcn_withoutLfreq(n=..., projection=...)`` -- explicit arguments.

    The eigendecomposition comes from ``get_diag`` when an options dict with
    a precomputed file is available. Otherwise it is computed from the dense
    ``gcn(adj)`` with ``eigsh``, as ``gcn_withLfreq`` does. The leading ``n``
    eigenpairs are the smallest eigenvalues if they are sorted ascending.
    ``pubmed`` and ``gamma`` are accepted for backward compatibility but are
    unused. When projection is enabled, negative entries of the result are
    clamped to 0.
    """
    # fetch_normalization passes its options dict positionally, so it lands in ``n``.
    options = n if isinstance(n, dict) else None

    def gcn_withoutLfreq_aux(adj):
        gcn_adj = gcn(adj)
        if options is not None:
            nfreq, project = options['nfreq'], options['projection']
            adj_diag = get_diag(options['dataset'], options)
        else:
            nfreq, project = n, projection
            adj_diag = None
        if adj_diag is None:
            adj_diag = eigsh(gcn_adj.toarray(), adj.shape[0])
        eigvals, eigvecs = adj_diag[0], adj_diag[1]
        leading = eigvecs[:, :nfreq]
        # Column scaling equals V @ diag(w) without building a dense N x N diagonal.
        nedges = (leading * eigvals[:nfreq]).dot(leading.transpose())
        if project:
            return proj_positive(gcn_adj - nedges)
        return gcn_adj - nedges
    return gcn_withoutLfreq_aux

def gcn_withLfreq(kwargs):
    """Build a normalizer keeping only the ``nfreq`` trailing eigencomponents.

    The returned function takes the ``(eigenvalues, eigenvectors)`` pair from
    ``get_diag``. If that is unavailable, it computes the pair from the dense
    ``gcn(adj)`` with ``eigsh``. It then returns ``V_n diag(w_n) V_n^T`` for
    the ``kwargs['nfreq']`` trailing eigenpairs, which are the largest
    eigenvalues if they are sorted ascending. ``nfreq == 0`` yields an
    all-zero matrix; ``nfreq`` larger than the matrix keeps every eigenpair.

    Options read from ``kwargs`` at call time: ``nfreq``, ``projection``,
    ``dataset`` and, through ``get_diag``, ``datapath``. When ``projection``
    is truthy, negative entries are clamped to 0 with ``proj_positive``.
    """
    def gcn_withLfreq_aux(adj):
        n, projection, dataset = kwargs['nfreq'], kwargs['projection'], kwargs['dataset']
        adj_diag = get_diag(dataset, kwargs)
        if adj_diag is None:
            gcn_adj = gcn(adj)
            adj_diag = eigsh(gcn_adj.toarray(), adj.shape[0])
        eigvals, eigvecs = adj_diag[0], adj_diag[1]
        # ``[-n:]`` would select *every* column when n == 0; use an explicit start.
        start = max(len(eigvals) - n, 0)
        top = eigvecs[:, start:]
        # Column scaling equals V @ diag(w) without building a dense N x N diagonal.
        nedges = (top * eigvals[start:]).dot(top.transpose())
        if projection:
            return proj_positive(nedges)
        return nedges
    return gcn_withLfreq_aux

def aug_normalized_adjacency(adj):
   adj = adj + sp.eye(adj.shape[0])
   adj = sp.coo_matrix(adj)
   row_sum = np.array(adj.sum(1))
   d_inv_sqrt = np.power(row_sum, -0.5).flatten()
   d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
   d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
   return d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt).tocoo()

def bingge_norm_adjacency(adj):
    """Return ``I + (D + I)^-1/2 (A + I) (D + I)^-1/2`` in COO format.

    This is the renormalized adjacency with one extra identity added on top.
    A zero row sum yields a scaling factor of 0 rather than ``inf``.
    """
    n_nodes = adj.shape[0]
    with_loops = sp.coo_matrix(adj + sp.eye(n_nodes))
    inv_sqrt_deg = np.power(np.array(with_loops.sum(1)), -0.5).flatten()
    inv_sqrt_deg[np.isinf(inv_sqrt_deg)] = 0.
    scaling = sp.diags(inv_sqrt_deg)
    renormalized = scaling.dot(with_loops).dot(scaling)
    return (renormalized + sp.eye(with_loops.shape[0])).tocoo()

def normalized_adjacency(adj):
    """Return the symmetrically normalized adjacency ``D^-1/2 A D^-1/2`` in COO format.

    A zero row sum yields a scaling factor of 0 rather than ``inf``.
    """
    adj = sp.coo_matrix(adj)
    degrees = np.array(adj.sum(1))
    inv_sqrt_deg = np.power(degrees, -0.5).flatten()
    inv_sqrt_deg[np.isinf(inv_sqrt_deg)] = 0.
    scaling = sp.diags(inv_sqrt_deg)
    left_scaled = scaling.dot(adj)
    return left_scaled.dot(scaling).tocoo()

def random_walk_laplacian(adj):
    """Return the random-walk Laplacian ``I - D^-1 A`` in COO format.

    ``D`` holds the row sums of ``adj``. A zero-degree row gets an inverse
    degree of 0, as the other normalizers in this module do, instead of
    ``inf``. Without this, an explicitly stored zero in such a row would
    become ``inf * 0 = nan``.
    """
    adj = sp.coo_matrix(adj)
    row_sum = np.array(adj.sum(1))
    d_inv = np.power(row_sum, -1.0).flatten()
    d_inv[np.isinf(d_inv)] = 0.
    d_mat = sp.diags(d_inv)
    return (sp.eye(adj.shape[0]) - d_mat.dot(adj)).tocoo()


def aug_random_walk(adj):
    """Return the row-stochastic matrix ``(D + I)^-1 (A + I)`` in COO format.

    Self-loops are added first, and each row is then divided by its row sum.
    """
    with_loops = sp.coo_matrix(adj + sp.eye(adj.shape[0]))
    inv_deg = np.power(np.array(with_loops.sum(1)), -1.0).flatten()
    transition = sp.diags(inv_deg).dot(with_loops)
    return transition.tocoo()

def random_walk(adj):
    """Return the random-walk transition matrix ``D^-1 A`` in COO format.

    ``D`` holds the row sums of ``adj``. A zero-degree row gets an inverse
    degree of 0, as the other normalizers in this module do, instead of
    ``inf``. This keeps such rows all-zero rather than turning explicitly
    stored zeros into ``nan``.
    """
    adj = sp.coo_matrix(adj)
    row_sum = np.array(adj.sum(1))
    d_inv = np.power(row_sum, -1.0).flatten()
    d_inv[np.isinf(d_inv)] = 0.
    d_mat = sp.diags(d_inv)
    return d_mat.dot(adj).tocoo()

def no_norm(adj):
    """Return ``adj`` unnormalized, converted to COO format."""
    return sp.coo_matrix(adj)


def i_norm(adj):
    """Return ``A + I`` (self-loops added, no degree normalization) in COO format."""
    identity = sp.eye(adj.shape[0])
    return sp.coo_matrix(adj + identity)
  
def fetch_normalization(type, kwargs=None):
    """Look up an adjacency normalizer by name.

    Args:
        type: key naming the normalization (see the table below).
        kwargs: options dict handed to the factory-style normalizers
            (``gcnNoH``, ``gcnNoL``, ``gcnWithL``, ``gcn_new_self``,
            ``gcn_new_self_norm``, ``gcn_self_H``). They read it only when
            the returned function is called. Defaults to a fresh empty dict.
            A ``None`` default replaces the shared mutable ``{}`` default.

    Returns:
        A callable taking the adjacency matrix. For an unknown ``type``, the
        returned callable raises ``ValueError`` when invoked, instead of
        failing with an unrelated ``TypeError``.
    """
    if kwargs is None:
        kwargs = {}
    switcher = {
        'NormLap': normalized_laplacian,  # A' = I - D^-1/2 * A * D^-1/2
        'Lap': laplacian,  # A' = D - A
        'RWalkLap': random_walk_laplacian,  # A' = I - D^-1 * A
        'FirstOrderGCN': gcn,  # A' = I + D^-1/2 * A * D^-1/2
        'AugNormAdj': aug_normalized_adjacency,  # A' = (D + I)^-1/2 * (A + I) * (D + I)^-1/2
        'BingGeNormAdj': bingge_norm_adjacency,  # A' = I + (D + I)^-1/2 * (A + I) * (D + I)^-1/2
        'NormAdj': normalized_adjacency,  # A' = D^-1/2 * A * D^-1/2
        'RWalk': random_walk,  # A' = D^-1 * A
        'AugRWalk': aug_random_walk,  # A' = (D + I)^-1 * (A + I)
        'NoNorm': no_norm,  # A' = A
        'INorm': i_norm,  # A' = A + I
        'gcnNoH': gcn_withoutHfreq(kwargs),
        'gcnNoL': gcn_withoutLfreq(kwargs),
        'gcnWithL': gcn_withLfreq(kwargs),
        'test_norm': test_norm,
        'gcn_new_self': gcn_new_self(kwargs),
        'gcn_new_self_norm': gcn_new_self_norm(kwargs),
        'gcn_self_H': gcn_new_self_with_L(kwargs),
        'MLP': MLP_with_HF,
        'order2': order_2,
    }

    def _invalid(*args, **kw):
        # Raised lazily so that fetching an unknown name still succeeds, as before.
        raise ValueError("Invalid normalization technique: %r" % (type,))

    return switcher.get(type, _invalid)

def row_normalize(mx):
    """Scale each row of the sparse matrix ``mx`` so that it sums to 1.

    Computes ``diag(1 / rowsum) @ mx``. Rows whose sum is zero get a factor
    of 0 instead of ``inf`` and so stay zero.
    """
    inv_row_sums = np.power(np.array(mx.sum(1)), -1).flatten()
    inv_row_sums[np.isinf(inv_row_sums)] = 0.
    return sp.diags(inv_row_sums).dot(mx)

