# Dataset index (the integer `choice` argument used throughout this module):
# Cora: 1, Citeseer: 2, Pubmed: 3, Physics: 4, CS: 5, Computers: 6,
# Photos: 7, Reddit: 8, Github_social: 9, Twitch.DE: 10, Twitch.FR: 11, Wiki.Croc: 12,
# Wiki.squirrel: 13


from dgl.data import citation_graph as citegrh
from dgl.data import reddit
from dgl.data import gnn_benckmark as gnnbnch
from scipy import sparse
from scipy.sparse import csr_matrix
import numpy as np
import torch
from dgl import DGLGraph
import networkx as nx
from dgl import transform
from csv_json_data_loader import snap_data_loader
import pandas as pd
from scipy import sparse, stats
from numpy import inf
import random

# File locations and short names of the SNAP datasets. The four lists are
# parallel: position i of each list describes the same dataset, and load_data
# indexes all of them with ``choice - 9``.
SNAP_edge_files = [
    'git_web_ml/musae_git_edges.csv',
    'twitch/DE/musae_DE_edges.csv',
    'twitch/FR/musae_FR_edges.csv',
    'wikipedia/musae_crocodile_edges.csv',
    'wikipedia/musae_squirrel_edges.csv',
]
SNAP_features_files = [
    'git_web_ml/musae_git_features.json',
    'twitch/DE/musae_DE_features.json',
    'twitch/FR/musae_FR_features.json',
    'wikipedia/musae_crocodile_features.json',
    'wikipedia/musae_squirrel_features.json',
]
SNAP_label_files = [
    'git_web_ml/musae_git_target.csv',
    'twitch/DE/musae_DE_target.csv',
    'twitch/FR/musae_FR_target.csv',
    'wikipedia/musae_crocodile_target.csv',
    'wikipedia/musae_squirrel_target.csv',
]
SNAP_dataset_name = [
    'git',
    'Twitch_DE',
    'Twitch_FR',
    'crocs',
    'squirrels',
]


def generate_L(Ag, N):
    """Build the graph Laplacian ``D - A`` from an adjacency matrix.

    Args:
        Ag: N x N adjacency matrix exposing ``.dot`` (a scipy sparse matrix
            in this module).
        N: number of nodes, i.e. the side length of ``Ag``.

    Returns:
        The matrix ``Dg - Ag`` where ``Dg`` is the diagonal CSR matrix of
        row sums of ``Ag``.
    """
    node_ids = np.arange(N)
    # Multiplying by the all-ones vector yields the row sums (weighted degrees).
    degrees = Ag.dot(np.ones(N))
    degree_matrix = csr_matrix((degrees, (node_ids, node_ids)), shape=(N, N))
    laplacian = degree_matrix - Ag
    print("Finished building the Laplacian")
    return laplacian


def _random_split_masks(N):
    """Draw a random train / validation / test partition of ``N`` nodes.

    Roughly 10% of the nodes are assigned to training, roughly 20% to
    validation and the remainder to testing. The permutation is drawn from
    NumPy's global RNG, so seed ``np.random`` beforehand for a reproducible
    split.

    Args:
        N: number of nodes to partition.

    Returns:
        ``(mask, test_mask, val_mask)``: three ``torch.BoolTensor`` objects
        built from length-``N`` indicator arrays (train, test, validation).
    """
    num_train = int(np.round(1 * N / 10))
    print("Num. train", num_train)
    num_val = int(np.round(1 * N / 5))
    print("Num. val", num_val)

    ind = np.random.permutation(N)
    ind_train = ind[0:num_train]
    ind_val = ind[num_train:num_train + num_val]
    ind_test = ind[num_train + num_val:]
    print("Num. test", N - num_train - num_val)

    mask = np.zeros(N)
    test_mask = np.zeros(N)
    val_mask = np.zeros(N)

    mask[ind_train] = 1
    test_mask[ind_test] = 1
    val_mask[ind_val] = 1

    return torch.BoolTensor(mask), torch.BoolTensor(test_mask), torch.BoolTensor(val_mask)


def load_data(choice: int):
    """Load one of the benchmark graphs selected by its dataset index.

    Index ranges handled here:
        * ``choice < 4``: citation graphs (1 = Cora, 2 = Citeseer, anything
          else below 4 falls through to Pubmed). Train/val/test masks are
          taken from the dataset object.
        * ``4 <= choice <= 7``: Coauthor physics / cs and AmazonCoBuy
          computers / photo, with a random split.
        * ``choice == 8``: Reddit, with the masks from the dataset object.
        * ``choice > 8``: SNAP datasets read through ``snap_data_loader``
          using the ``SNAP_*`` file lists (index ``choice - 9``), with a
          random split. An index past the end of those lists raises
          ``IndexError``.

    Args:
        choice: dataset index as listed above.

    Returns:
        Tuple ``(g, features, labels, mask, test_mask, N, Ne, val_mask)``:
        the DGL graph with self loops removed, float feature tensor, long
        label tensor, boolean train and test masks, node count, edge count
        and boolean validation mask.

    Note:
        ``Ne`` is not computed uniformly. For ``choice < 4`` it is the raw
        edge count of ``data.graph`` taken before self loops are removed;
        for 4-7 it is half the edge count after removal and for 8 half the
        edge count before removal (both via true division, so a float); for
        ``choice > 8`` it is whatever the SNAP loader reports as ``data.Ne``.
    """
    if choice < 4:
        if choice == 1:
            data = citegrh.load_cora()
        elif choice == 2:
            data = citegrh.load_citeseer()
        else:
            data = citegrh.load_pubmed()

        features = torch.FloatTensor(data.features)
        labels = torch.LongTensor(data.labels)
        mask = torch.BoolTensor(data.train_mask)
        test_mask = torch.BoolTensor(data.test_mask)
        val_mask = torch.BoolTensor(data.val_mask)

        g = data.graph
        Ne = g.number_of_edges()
        N = g.number_of_nodes()

        # remove self loop
        g.remove_edges_from(nx.selfloop_edges(g))
        g = DGLGraph(g)

        return g, features, labels, mask, test_mask, N, Ne, val_mask

    if 3 < choice < 8:
        if choice == 4:  # num_classes = 5
            data = gnnbnch.Coauthor('physics')
        elif choice == 5:  # num_classes = 15
            data = gnnbnch.Coauthor('cs')
        elif choice == 6:  # num_classes = 67
            data = gnnbnch.AmazonCoBuy('computers')
        else:  # num_classes = 8
            data = gnnbnch.AmazonCoBuy('photo')

        g = data[0]

        features = torch.FloatTensor(g.ndata['feat'])
        labels = torch.LongTensor(g.ndata['label'])

        g = transform.remove_self_loop(g)
        N = g.number_of_nodes()
        Ne = g.number_of_edges() / 2

        mask, test_mask, val_mask = _random_split_masks(N)

        return g, features, labels, mask, test_mask, N, Ne, val_mask

    if choice == 8:
        data = reddit.RedditDataset()
        g = data.graph
        N = g.number_of_nodes()
        Ne = g.number_of_edges() / 2
        features = torch.FloatTensor(data.features)
        labels = torch.LongTensor(data.labels)
        mask = torch.BoolTensor(data.train_mask)
        test_mask = torch.BoolTensor(data.test_mask)
        val_mask = torch.BoolTensor(data.val_mask)

        g = transform.remove_self_loop(g)

        return g, features, labels, mask, test_mask, N, Ne, val_mask

    if choice > 8:
        # The SNAP lists are parallel; 9 is the index of their first entry.
        snap_index = choice - 9
        edge_filename = SNAP_edge_files[snap_index]
        label_filename = SNAP_label_files[snap_index]
        feature_filename = SNAP_features_files[snap_index]
        filename = SNAP_dataset_name[snap_index]

        data = snap_data_loader(edge_filename, label_filename, feature_filename, filename)

        N = data.N
        Ne = data.Ne
        features = torch.FloatTensor(data.X)
        labels = torch.LongTensor(data.L[:, 2].astype(int))

        mask, test_mask, val_mask = _random_split_masks(N)

        g = DGLGraph()
        g.from_scipy_sparse_matrix(csr_matrix(data.A))
        g = transform.remove_self_loop(g)

        return g, features, labels, mask, test_mask, N, Ne, val_mask


def generate_spare_graph(g, N, choice, epsilon):
    """Sparsify ``g`` via ``graph_sparsify`` and wrap the result as a DGL graph.

    Args:
        g: DGL graph to sparsify. Note that self loops are added to this
            input graph in place before returning.
        N: number of nodes of ``g``.
        choice: dataset index, forwarded to ``graph_sparsify``.
        epsilon: sparsification parameter, forwarded to ``graph_sparsify``.

    Returns:
        ``(g_sp, Ne_sp)``: the sparsified graph (with self loops added) and
        the edge count reported by ``graph_sparsify``.
    """
    adjacency = g.adjacency_matrix_scipy(return_edge_ids=False)
    laplacian = generate_L(adjacency, N)

    print("Sparsifying the graph: ")
    sparse_weights, Ne_sp = graph_sparsify(laplacian, epsilon, choice)
    print("Finished sparsifying the graph: ")

    g_sp = DGLGraph()
    g_sp.from_scipy_sparse_matrix(sparse_weights)

    # Add a self loop on every node, first to the input graph, then to the
    # sparsified one.
    for graph in (g, g_sp):
        graph.add_edges(graph.nodes(), graph.nodes())
    return g_sp, Ne_sp


def generate_randomly_spare_graph(g, N, Ne, epsilon, rand_seed=42):
    """Sparsify ``g`` by sampling its edges uniformly at random.

    ``q`` edges are drawn with replacement from the lower triangle of the
    adjacency matrix; an edge drawn ``k`` times gets weight
    ``k * weight * Ne / q`` and the result is symmetrised.

    Args:
        g: DGL graph to sparsify. Self loops are added to this input graph
            in place before returning.
        N: number of nodes of ``g``.
        Ne: edge count used to rescale the sampled weights.
        epsilon: sparsification parameter; the number of draws scales as
            ``N * log(N) / epsilon ** 2``.
        rand_seed: seed of the sampler; must be a value accepted by
            ``np.random.RandomState``.

    Returns:
        ``(g_sp, Ne_sp)``: the sparsified DGL graph with self loops added,
        and ``Ne_sp = q``.

    NOTE(review): ``Ne_sp`` is the number of draws ``q``, not the number of
    distinct edges kept (which is what gets printed) -- confirm that callers
    expect this.
    """
    # sparsification
    Ag = g.adjacency_matrix_scipy(return_edge_ids=False)

    print("Sparsifying the graph: ")
    # Number of edges to keep
    C0 = 1 / 30.
    # Rudelson and Vershynin, 2007, Thm. 3.1
    C = 4 * C0
    q = round(N * np.log(N) * 9 * C ** 2 / (epsilon ** 2))

    start_nodes, end_nodes, weights = sparse.find(sparse.tril(Ag))
    num_edges = weights.shape[0]

    # The original code only seeded Python's ``random`` module, which has no
    # influence on the NumPy sampler used below, so ``rand_seed`` was ignored.
    # The call is retained in case callers rely on that reseeding side effect.
    random.seed(rand_seed)
    # A dedicated generator makes the draw reproducible without touching
    # NumPy's global RNG state.
    rng = np.random.RandomState(rand_seed)
    results = rng.choice(num_edges, int(q))
    # np.unique(..., return_counts=True) is the documented replacement for
    # the deprecated scipy.stats.itemfreq.
    sampled_edges, hits = np.unique(results, return_counts=True)

    # q can be 0 (e.g. N == 1); zero out the resulting inf/nan entries.
    with np.errstate(divide='ignore', invalid='ignore'):
        per_spin_weights = weights * Ne / q
    per_spin_weights[~np.isfinite(per_spin_weights)] = 0

    counts = np.zeros(num_edges)
    counts[sampled_edges] = hits
    new_weights = counts * per_spin_weights

    sparserW = sparse.csc_matrix((new_weights, (start_nodes, end_nodes)),
                                 shape=(N, N))
    print("Number of edges after sparsification: ", sparserW.count_nonzero())

    # Only the lower triangle was sampled; mirror it to restore symmetry.
    sparserW = sparserW + sparserW.T

    Wsp, Ne_sp = sparserW, q

    print("Finished sparsifying the graph: ")

    g_sp = DGLGraph()
    g_sp.from_scipy_sparse_matrix(Wsp)

    # add self loop
    g.add_edges(g.nodes(), g.nodes())
    g_sp.add_edges(g_sp.nodes(), g_sp.nodes())
    return g_sp, Ne_sp


def graph_sparsify(Lg, epsilon, choice, rand_seed=42):
    """Sparsify a graph by sampling edges proportionally to weight * resistance.

    The adjacency matrix is recovered from the Laplacian as ``W = D - L``.
    Per-edge resistances come from a file chosen by ``choice``: either a
    matrix ``V`` (CSV) from which ``compute_reff`` derives them, or a
    resistance matrix readable by ``np.loadtxt``. ``q`` edges are then drawn
    with replacement with probability ``Pe`` proportional to
    ``weight * resistance``; an edge drawn ``k`` times receives weight
    ``k * weight / (q * Pe)``.

    Args:
        Lg: N x N graph Laplacian as a scipy sparse matrix.
        epsilon: sparsification parameter; the number of draws scales as
            ``N * log(N) / epsilon ** 2``.
        choice: 1-based dataset index selecting the resistance file.
        rand_seed: seed of the sampler (default 42, the value the original
            code attempted to use); must be accepted by
            ``np.random.RandomState``.

    Returns:
        ``(sparserW, Ne_sp)``: the symmetric sparsified weight matrix and
        the number of edges with a non-zero new weight.
    """
    # NOTE(review): entry 7 is spelled 'Amazon_photo_Reff.txt' here but
    # 'Amazon_Photo_Reff.txt' in generate_reff; on a case-sensitive file
    # system at most one of the two can exist -- confirm the real name.
    filenames = ['V_Cora.csv', 'citeseer_Reff.txt', 'V_Pubmed.csv', 'V_Phy.csv', 'V_CS.csv',
                 'Amazon_computers_Reff.txt',
                 'Amazon_photo_Reff.txt', 'V_R_eff_Reddit.csv', 'V_git.csv', 'V_twitch_DE.csv', 'V_twitch_FR.csv',
                 'V_wiki_crocs.csv', 'V_wiki_squirrels.csv']

    filename = filenames[choice - 1]
    print("Computing resistances for ", filename)

    N = Lg.shape[0]
    Dv = Lg.diagonal()
    Dg = csr_matrix((Dv, (np.arange(N), np.arange(N))), shape=(N, N))
    W = Dg - Lg
    # Datasets whose file holds a V matrix rather than precomputed resistances.
    read_V = [1, 3, 4, 5, 8, 9, 10, 11, 12, 13]
    uses_V = choice in read_V

    if uses_V:
        print("Reading V matrix:..")
        V_frame = pd.read_csv(filename, header=None)
        V = V_frame.to_numpy()
        print("Computing edge resistances:... ")

        resistance_distances = compute_reff(W, V)
        print("Finished loading resistances:")

    else:
        resistance_distances = np.loadtxt(filename)

    start_nodes, end_nodes, weights = sparse.find(sparse.tril(W))

    # Calculate the new weights.
    weights = np.maximum(0, weights)
    edge_resistances = resistance_distances[start_nodes, end_nodes]
    if uses_V:
        # compute_reff returns a sparse matrix; densify the selected entries.
        edge_resistances = edge_resistances.toarray()
    # ravel (rather than squeeze) keeps a 1-D array even for a single edge.
    Re = np.maximum(0, np.ravel(edge_resistances))
    Pe = weights * Re
    Pe = Pe / np.sum(Pe)

    # Rudelson, 1996 Random Vectors in the Isotropic Position
    # (too hard to figure out actual C0)
    C0 = 1 / 30.
    # Rudelson and Vershynin, 2007, Thm. 3.1
    C = 4 * C0
    q = round(N * np.log(N) * 9 * C ** 2 / (epsilon ** 2))

    # The original code only seeded Python's ``random`` module, which has no
    # influence on the NumPy sampler used below, so the draw was never
    # reproducible. The call is retained in case callers rely on that
    # reseeding side effect.
    random.seed(rand_seed)
    # A dedicated generator makes the draw reproducible without touching
    # NumPy's global RNG state.
    rng = np.random.RandomState(rand_seed)
    results = rng.choice(Pe.shape[0], size=int(q), p=Pe)
    # np.unique(..., return_counts=True) is the documented replacement for
    # the deprecated scipy.stats.itemfreq.
    sampled_edges, hits = np.unique(results, return_counts=True)

    # Edges with Pe == 0 give inf (or nan when the clamped weight is also 0);
    # they can never be drawn, so their per-draw weight is set to 0.
    with np.errstate(divide='ignore', invalid='ignore'):
        per_spin_weights = weights / (q * Pe)
    per_spin_weights[~np.isfinite(per_spin_weights)] = 0

    counts = np.zeros(weights.shape[0])
    counts[sampled_edges] = hits
    new_weights = counts * per_spin_weights

    sparserW = sparse.csc_matrix((new_weights, (start_nodes, end_nodes)),
                                 shape=(N, N))
    # Only the lower triangle was sampled; mirror it to restore symmetry.
    sparserW = sparserW + sparserW.T

    return sparserW, np.count_nonzero(new_weights)


def compute_reff(W, V):
    """Compute a per-edge resistance matrix from the rows of ``V``.

    For every non-zero entry ``(i, j)`` in the lower triangle of ``W`` the
    result holds the squared Euclidean distance between rows ``i`` and ``j``
    of ``V``; all other entries stay empty.

    Args:
        W: square weight/adjacency matrix (scipy sparse in this module).
        V: 2-D array indexed by node along its first axis.

    Returns:
        An n x n ``scipy.sparse.lil_matrix`` of the distances described above.
    """
    rows, cols, _ = sparse.find(sparse.tril(W))
    num_nodes = np.shape(W)[0]
    Reff = sparse.lil_matrix((num_nodes, num_nodes))
    for u, v in zip(rows, cols):
        delta = V[u, :] - V[v, :]
        Reff[u, v] = np.linalg.norm(delta) ** 2
    return Reff


def generate_reff(g, N, choice):
    """Return the resistance distances stored on disk for dataset ``choice``.

    Depending on ``choice`` the associated file is either a CSV holding a
    matrix ``V``, from which ``compute_reff`` derives one value per edge of
    ``g``, or a text file that ``np.loadtxt`` reads and returns directly.

    Args:
        g: DGL graph whose adjacency matrix supplies the edge set.
        N: unused here; kept in the signature for callers.
        choice: 1-based dataset index selecting the file.

    Returns:
        The matrix returned by ``compute_reff`` for V-backed datasets,
        otherwise the array loaded with ``np.loadtxt``.
    """
    reff_sources = (
        'V_Cora.csv', 'citeseer_Reff.txt', 'V_Pubmed.csv', 'V_Phy.csv', 'V_CS.csv',
        'Amazon_computers_Reff.txt',
        'Amazon_Photo_Reff.txt', 'V_R_eff_Reddit.csv', 'V_git.csv', 'V_twitch_DE.csv', 'V_twitch_FR.csv',
        'V_wiki_crocs.csv', 'V_wiki_squirrels.csv', 'V_reddit_subset.csv',
    )
    # Dataset indices whose file holds a V matrix rather than resistances.
    v_backed = (1, 3, 4, 5, 8, 9, 10, 11, 12, 13, 14)

    filename = reff_sources[choice - 1]

    # The adjacency matrix is fetched up front, as in every code path.
    W = g.adjacency_matrix_scipy(return_edge_ids=False)

    if choice not in v_backed:
        return np.loadtxt(filename)

    print("Reading V matrix:..")
    V = pd.read_csv(filename, header=None).to_numpy()
    print("Computing edge resistances:... ")
    resistance_distances = compute_reff(W, V)
    print("Finished loading resistances:")
    return resistance_distances
