import numpy as np
import torch
import dgl
import torch.nn.functional as F
import argparse
from sklearn.metrics import f1_score
from gat import GAT
from dgl.data.ppi import LegacyPPIDataset
from torch.utils.data import DataLoader
import time

from dgl.data import citation_graph as citegrh
from dgl.data import reddit
from dgl.data import gnn_benckmark as gnnbnch
from scipy import sparse
from scipy.sparse import csr_matrix

from dgl import DGLGraph
import networkx as nx
from dgl import  transform
from csv_json_data_loader import snap_data_loader
import pandas as pd
import scipy
from scipy import sparse, stats
from numpy import inf
import random


def sparsifyPPI(dataset, epsilon):
    """Replace every training graph of ``dataset`` with a randomly sparsified copy.

    Each graph in ``dataset.train_graphs`` is passed to
    ``generate_randomly_sparse_graph`` and the result is written back into
    ``dataset.train_graphs`` at the same index, so the dataset is modified in
    place. The same object is returned for convenience.

    Args:
        dataset: object exposing an indexable, assignable ``train_graphs``
            sequence of graphs with ``number_of_nodes()`` / ``number_of_edges()``
            (presumably a DGL ``LegacyPPIDataset`` -- confirm with the caller).
        epsilon: approximation parameter forwarded to the sparsifier; smaller
            values keep more sampled edges.

    Returns:
        ``dataset`` itself, with its training graphs replaced.
    """
    # Bound the loop by the sequence actually being indexed: len(dataset) is
    # not guaranteed to equal the number of training graphs.
    for idx in range(len(dataset.train_graphs)):
        g = dataset.train_graphs[idx]
        N = g.number_of_nodes()
        Ne = g.number_of_edges()
        # Alternative: generate_sparse_graph(g, N, idx, epsilon) samples by
        # effective resistance but needs a precomputed V_PPIG<idx>.csv file.
        g_sp, _ = generate_randomly_sparse_graph(g, N, Ne, epsilon)
        dataset.train_graphs[idx] = g_sp

    return dataset

def generate_L(Ag, N):
    """Build the graph Laplacian ``L = D - A`` of a sparse adjacency matrix.

    ``D`` is the diagonal matrix of row sums of ``Ag`` (``Ag @ 1``).

    Args:
        Ag: sparse ``(N, N)`` adjacency matrix (weights are honoured).
        N: number of nodes.

    Returns:
        The Laplacian as a sparse matrix (CSR arithmetic result).
    """
    degrees = Ag.dot(np.ones(N))
    diag_idx = np.arange(N)
    degree_matrix = csr_matrix((degrees, (diag_idx, diag_idx)), shape=(N, N))
    laplacian = degree_matrix - Ag
    print("Finished building the Laplacian")
    return laplacian

  
  
def generate_sparse_graph(g, N, i, epsilon):
    """Sparsify ``g`` with ``graph_sparsify`` and wrap the result in a new DGLGraph.

    The adjacency of ``g`` is turned into a Laplacian (``generate_L``), which is
    sparsified by ``graph_sparsify`` using the V matrix file for graph index
    ``i``. A self loop is then added on every node of the new graph.

    Args:
        g: input graph providing ``adjacency_matrix_scipy`` (DGLGraph API).
        N: number of nodes of ``g``.
        i: graph index, used by ``graph_sparsify`` to pick its V-matrix file.
        epsilon: approximation parameter forwarded to ``graph_sparsify``.

    Returns:
        ``(g_sp, Ne_sp)``: the new graph and the count returned by ``graph_sparsify``.
    """
    laplacian = generate_L(g.adjacency_matrix_scipy(return_edge_ids=False), N)

    print("Sparsifying the graph: ")
    sparse_adj, kept_edges = graph_sparsify(laplacian, epsilon, i)
    print("Finished sparsifying the graph: ")

    sparse_graph = DGLGraph()
    sparse_graph.from_scipy_sparse_matrix(sparse_adj)

    # Self loop on every node of the sparsified graph.
    every_node = sparse_graph.nodes()
    sparse_graph.add_edges(every_node, every_node)
    return sparse_graph, kept_edges

def graph_sparsify(Lg, epsilon, i, rand_seed=None):
    """Sparsify a graph by sampling edges proportionally to weight x resistance.

    The weighted adjacency ``W = D - L`` is recovered from the Laplacian ``Lg``.
    Resistances come from ``compute_reff`` using the node matrix ``V`` read from
    ``V_PPIG<i>.csv``. NOTE(review): this assumes ``||V[u] - V[v]||^2``
    approximates the effective resistance of edge ``(u, v)``; confirm how the
    CSV files were produced.

    Each lower-triangular edge ``e`` gets probability ``p_e ∝ w_e * R_e``. ``q``
    draws are taken with replacement, and each draw of ``e`` contributes
    ``w_e / (q * p_e)``. The result is symmetrised.

    Args:
        Lg: sparse ``(N, N)`` graph Laplacian.
        epsilon: approximation parameter; ``q`` grows as ``1 / epsilon**2``.
        i: graph index selecting the file ``V_PPIG<i>.csv``.
        rand_seed: optional seed for the sampler. ``None`` (default) uses
            NumPy's global random state.

    Returns:
        ``(sparserW, n_kept)``: the symmetric sparse ``(N, N)`` reweighted
        adjacency (CSC), and the number of distinct lower-triangular edges
        kept.

    Raises:
        ValueError: if no edge has positive weight x resistance, so no
            sampling distribution exists.
    """
    filename = 'V_PPIG' + str(i) + '.csv'

    N = np.size(Lg, 0)
    # Recover the weighted adjacency W = D - L (the diagonal cancels out).
    Dv = Lg.diagonal()
    Dg = csr_matrix((Dv, (np.arange(N), np.arange(N))), shape=(N, N))
    W = Dg - Lg

    print("Reading V matrix:..")
    V = pd.read_csv(filename, header=None).to_numpy()
    print("Computing edge resistances:... ")
    resistance_distances = compute_reff(W, V)
    print("Finished loading resistances:")

    # Each undirected edge appears once in the lower triangle.
    start_nodes, end_nodes, weights = sparse.find(sparse.tril(W))
    num_edges = weights.shape[0]

    # Sampling probabilities p_e ∝ w_e * R_e (negatives clipped to zero).
    weights = np.maximum(0, weights)
    # ravel (not squeeze) keeps a 1-D array even when there is a single edge.
    Re = np.maximum(0, np.ravel(resistance_distances[start_nodes, end_nodes].toarray()))
    Pe = weights * Re
    total = np.sum(Pe)
    if not total > 0:
        raise ValueError("graph_sparsify: sampling probabilities sum to zero; "
                         "no edge has positive weight and resistance")
    Pe = Pe / total

    # Rudelson, 1996 Random Vectors in the Isotropic Position
    # (too hard to figure out actual C0)
    C0 = 1 / 30.
    # Rudelson and Vershynin, 2007, Thm. 3.1
    C = 4 * C0
    q = round(N * np.log(N) * 9 * C ** 2 / (epsilon ** 2))

    rng = np.random if rand_seed is None else np.random.RandomState(rand_seed)
    results = rng.choice(np.arange(num_edges), int(q), p=Pe)
    # Number of times each edge was drawn (stats.itemfreq was removed in SciPy 1.3).
    counts = np.bincount(results, minlength=num_edges)

    with np.errstate(divide='ignore', invalid='ignore'):
        per_spin_weights = weights / (q * Pe)
    # Edges with p_e == 0 are never drawn; zero their per-draw weight instead of
    # leaving inf (w > 0) or nan (w == 0) that would leak into the output.
    per_spin_weights[~np.isfinite(per_spin_weights)] = 0
    new_weights = counts * per_spin_weights

    sparserW = sparse.csc_matrix((new_weights, (start_nodes, end_nodes)),
                                 shape=(N, N))
    sparserW = sparserW + sparserW.T

    return sparserW, np.count_nonzero(new_weights)


def compute_reff(W, V):
    """Return squared row distances of ``V`` on the lower-triangular edges of ``W``.

    For every nonzero entry ``(u, v)`` of ``tril(W)``, including the diagonal,
    the result holds ``||V[u, :] - V[v, :]||_2 ** 2`` at position ``(u, v)``.
    All other positions are zero.

    Args:
        W: sparse ``(n, n)`` matrix; only its lower-triangular sparsity
            pattern is used.
        V: 2-D array with one row per node.

    Returns:
        ``scipy.sparse.lil_matrix`` of shape ``(n, n)``. Zero distances are
        not stored explicitly.
    """
    start_nodes, end_nodes, _ = sparse.find(sparse.tril(W))
    n = np.shape(W)[0]

    # Vectorised over all edges instead of a Python loop with per-entry lil writes.
    diffs = V[start_nodes, :] - V[end_nodes, :]
    sq_dist = np.linalg.norm(diffs, axis=1) ** 2

    # Assigning 0 into a lil_matrix stores nothing; skip zeros to keep that sparsity.
    keep = sq_dist != 0
    Reff = sparse.coo_matrix(
        (sq_dist[keep], (start_nodes[keep], end_nodes[keep])), shape=(n, n)
    ).tolil()
    return Reff

def generate_randomly_sparse_graph(g, N, Ne, epsilon, rand_seed=42):
    """Sparsify ``g`` by sampling lower-triangular adjacency entries uniformly.

    ``q = round(N log N * 9 C^2 / epsilon^2)`` draws (with replacement, ``C = 4/30``)
    are taken uniformly over the nonzero entries of ``tril(A)``. Each draw of an
    entry adds ``weight * Ne / q`` to it. The result is symmetrised, loaded into
    a new DGLGraph, and a self loop is added on every node.

    Args:
        g: input graph providing ``adjacency_matrix_scipy`` (DGLGraph API).
        N: number of nodes of ``g``.
        Ne: number of edges of ``g``; used as the reweighting scale.
        epsilon: approximation parameter; ``q`` grows as ``1 / epsilon**2``.
        rand_seed: seed for the edge sampler, so results are reproducible.

    Returns:
        ``(g_sp, n_kept)``: the sparsified graph, and the number of distinct
        lower-triangular entries sampled at least once.
    """
    # sparsification
    Ag = g.adjacency_matrix_scipy(return_edge_ids=False)

    print("Sparsifying the graph: ")
    # Number of draws. Rudelson and Vershynin, 2007, Thm. 3.1
    C0 = 1 / 30.
    C = 4 * C0
    q = int(round(N * np.log(N) * 9 * C ** 2 / (epsilon ** 2)))

    start_nodes, end_nodes, weights = sparse.find(sparse.tril(Ag))
    num_candidates = weights.shape[0]
    print("Ne, length(weights):", Ne, np.shape(weights))

    # A local RandomState makes rand_seed actually control the draws (seeding
    # Python's `random` module does not affect NumPy) without touching global state.
    rng = np.random.RandomState(rand_seed)
    if q > 0 and num_candidates > 0:
        draws = rng.choice(num_candidates, q)
        counts = np.bincount(draws, minlength=num_candidates)
        # NOTE(review): uniform sampling gives p = 1/num_candidates, so an unbiased
        # reweighting would use num_candidates / q; Ne / q is kept as before --
        # confirm whether edge weight magnitudes matter downstream.
        new_weights = counts * (weights * Ne / q)
    else:
        # Nothing can be sampled (q == 0, e.g. N == 1, or the graph has no edges).
        new_weights = np.zeros(num_candidates)

    sparserW = sparse.csc_matrix((new_weights, (start_nodes, end_nodes)),
                                 shape=(N, N))
    print("Number of edges after sparsification: ", sparserW.count_nonzero())

    sparserW = sparserW + sparserW.T

    print("Finished sparsifying the graph: ")

    g_sp = DGLGraph()
    g_sp.from_scipy_sparse_matrix(sparserW)

    # add self loop
    g_sp.add_edges(g_sp.nodes(), g_sp.nodes())
    return g_sp, np.count_nonzero(new_weights)
