import torch
import torch.nn as nn
import torch.nn.functional as F
from pygsp import graphs, filters, reduction, utils
from dgl import DGLGraph
from dgl import  transform
from dgl.data import citation_graph as citegrh
from dgl.data import reddit
from dgl.data import gnn_benckmark as gnnbnch
import time
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import scipy
from scipy import sparse, stats
import GAT_defn
from csv_json_data_loader import snap_data_loader
from scipy.sparse import  csr_matrix

# Experiment setup to study the computation time and accuracy when sparsified graphs are used in GATs

# Datasets: Cora, Cora_full, Citeseer, Pubmed.

# Cora_full does not come with a pre-defined partition into train, validation and test sets; it has to be split manually.
# Self edges also need to be removed first, and then added back again for all nodes.

# Relative paths of the SNAP (MUSAE) dataset files, one entry per dataset.
# Entry i of every list below belongs to the same dataset.

# Edge lists (CSV).
SNAP_edge_files = [
    'git_web_ml/musae_git_edges.csv',
    'twitch/DE/musae_DE_edges.csv',
    'twitch/FR/musae_FR_edges.csv',
    'wikipedia/crocodile/musae_crocodile_edges.csv',
    'wikipedia/squirrel/musae_squirrel_edges.csv',
]

# Node features (JSON).
SNAP_features_files = [
    'git_web_ml/musae_git_features.json',
    'twitch/DE/musae_DE_features.json',
    'twitch/FR/musae_FR_features.json',
    'wikipedia/crocodile/musae_crocodile_features.json',
    'wikipedia/squirrel/musae_squirrel_features.json',
]

# Node targets / labels (CSV).
SNAP_label_files = [
    'git_web_ml/musae_git_target.csv',
    'twitch/DE/musae_DE_target.csv',
    'twitch/FR/musae_FR_target.csv',
    'wikipedia/crocodile/musae_crocodile_target.csv',
    'wikipedia/squirrel/musae_squirrel_target.csv',
]

# Short names of the datasets, in the same order as the file lists.
SNAP_dataset_name = [
    'git',
    'Twitch_DE',
    'Twitch_FR',
    'crocs',
    'squirrels',
]


def load_cora_data(choice):
    """Load a citation dataset and return its adjacency matrix without self loops.

    Args:
        choice: Dataset selector. 1 loads Cora, 2 loads Citeseer and any
            other value loads Pubmed.

    Returns:
        The graph's adjacency matrix converted to SciPy CSC format.
    """
    if choice == 1:
        loader = citegrh.load_cora
    elif choice == 2:
        loader = citegrh.load_citeseer
    else:
        loader = citegrh.load_pubmed
    dataset = loader()

    # Remove self loops in place (the graph is handled through the NetworkX
    # API here) before wrapping it in a DGLGraph.
    raw_graph = dataset.graph
    self_loops = nx.selfloop_edges(raw_graph)
    raw_graph.remove_edges_from(self_loops)
    dgl_graph = DGLGraph(raw_graph)

    # Export the adjacency matrix (without edge ids) and convert it to CSC.
    adjacency = dgl_graph.adjacency_matrix_scipy(return_edge_ids=False)
    return adjacency.tocsc()

def load_benchmark_data(choice):
    """Load a GNN benchmark dataset and return its adjacency matrix without self loops.

    Args:
        choice: Dataset selector. 4 loads Coauthor 'physics', 5 loads
            Coauthor 'cs', 6 loads AmazonCoBuy 'photo' and any other value
            loads AmazonCoBuy 'computers'.

    Returns:
        The adjacency matrix of the dataset's first graph, converted to
        SciPy CSC format.
    """
    # Number of classes as noted by the original author:
    # physics = 5, cs = 15, photo = 8, computers = 10.
    if choice == 4:
        factory, subset = gnnbnch.Coauthor, 'physics'
    elif choice == 5:
        factory, subset = gnnbnch.Coauthor, 'cs'
    elif choice == 6:
        factory, subset = gnnbnch.AmazonCoBuy, 'photo'
    else:
        factory, subset = gnnbnch.AmazonCoBuy, 'computers'
    dataset = factory(subset)

    # The graph is the first item of the dataset; drop its self loops.
    graph = transform.remove_self_loop(dataset[0])
    adjacency = graph.adjacency_matrix_scipy(return_edge_ids=False)
    return adjacency.tocsc()

def load_reddit():
    """Load the Reddit dataset and return its adjacency matrix in CSC format.

    No self-loop removal is performed on this graph.

    Returns:
        The graph's adjacency matrix converted to SciPy CSC format.
    """
    reddit_graph = reddit.RedditDataset().graph
    adjacency = reddit_graph.adjacency_matrix_scipy(return_edge_ids=False)
    return adjacency.tocsc()

# Base names for the saved adjacency-matrix files, grouped by dataset family.
filenames1 = ['Cora', 'Citeseer', 'Pubmed']        # citation graphs
filenames2 = ['Phy', 'CS', 'Photo', 'Computers']   # Coauthor / Amazon co-purchase graphs
filenames3 = ['Reddit']
filenames4 = ['git']


# Export the adjacency matrix of each selected DGL dataset with
# scipy.sparse.save_npz (which appends '.npz' to the name when missing).
# Selector convention: below 4 -> citation graphs, 4..7 -> Coauthor/Amazon
# benchmarks, anything else -> Reddit.
# With range(5, 7) only 'CS' and 'Photo' are exported.
for choice in range(5, 7):

    if choice < 4:
        filename = filenames1[choice - 1]
        print(f"Loading {filename}")
        Ag = load_cora_data(choice)
    elif 4 <= choice <= 7:
        filename = filenames2[choice - 4]
        Ag = load_benchmark_data(choice)
    else:
        filename = 'Reddit_sparse'
        Ag = load_reddit()

    print("Saving data now.. ")
    sparse.save_npz(filename, Ag)

# Export the adjacency matrix of every SNAP (MUSAE) dataset with
# scipy.sparse.save_npz (which appends '.npz' to the name when missing).
# Selectors 9..13 map onto positions 0..4 of the SNAP_* lists.
for choice in range(9, 14):

    idx = choice - 9  # position of this dataset in the SNAP_* lists

    edge_filename = './' + SNAP_edge_files[idx]
    label_filename = './' + SNAP_label_files[idx]
    feature_filename = './' + SNAP_features_files[idx]
    # Take the name from SNAP_dataset_name, which has one entry per dataset.
    # filenames4 holds only 'git', so indexing it here raised IndexError
    # from the second iteration onwards.
    filename = SNAP_dataset_name[idx]

    print(filename)

    data = snap_data_loader(edge_filename, label_filename, feature_filename, filename)
    # data.A is presumably the dataset's adjacency matrix -- confirm against
    # snap_data_loader. It is converted to CSR before saving.
    A = csr_matrix(data.A)

    print("Saving data now.. ")
    sparse.save_npz(filename, A)


