# grakel's default dataset download URL has changed,
# so datasets are downloaded from the new link instead (see load_data).

import requests
import sys
import os
import networkx as nx
import numpy as np
from grakel import Graph
from grakel.utils import graph_from_networkx
from grakel.datasets import fetch_dataset

def load_data(name, data_dir='../data/benchmark_data/', timeout=60):
    """Download a benchmark dataset from the new mirror and load it with grakel.

    grakel's built-in download link is outdated, so the archive is fetched
    from ``https://www.chrsmrrs.com/graphkerneldatasets/<name>.zip`` and
    saved as ``<data_dir>/<name>.zip``. ``fetch_dataset`` is then pointed at
    that directory through ``data_home``.

    Parameters
    ----------
    name : str
        Dataset name; used verbatim in the URL and in the file name.
    data_dir : str, optional
        Directory to store the archive in, resolved relative to the current
        working directory. Created if it does not exist.
    timeout : float or None, optional
        Seconds ``requests`` waits on the connection before giving up;
        ``None`` waits forever.

    Returns
    -------
    The object returned by ``fetch_dataset(name, data_home=..., as_graphs=True)``.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status.
    """
    url = "https://www.chrsmrrs.com/graphkerneldatasets/" + name + ".zip"
    path = os.path.abspath(data_dir)
    # A fresh checkout may not have the directory yet; open() would fail.
    os.makedirs(path, exist_ok=True)
    filename = os.path.join(path, name + ".zip")
    # Stream into a temporary file and move it into place only once the
    # download has completed, so an interrupted transfer never leaves a
    # truncated archive where fetch_dataset would pick it up.
    partial = filename + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()  # Raise an error on bad status
            with open(partial, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(partial, filename)
    finally:
        # Only still present if something above raised.
        if os.path.exists(partial):
            os.remove(partial)

    dataset = fetch_dataset(name, data_home=path, as_graphs=True)
    return dataset

def _assign_trivial_labels(G, A):
    """Overwrite the vertex and label attributes of ``G`` using matrix ``A``.

    Every vertex ``0..n-1`` gets the constant label ``'0'``. Every nonzero
    entry ``A[i, j]`` becomes edge ``(i, j)`` labelled with ``A[i, j]``. The
    same edges are written to ``G.edge_dictionary`` as ``{i: {j: A[i, j]}}``,
    where every vertex has an entry (possibly an empty dict).

    NOTE: vertex ids here are 0-based, whereas add_labels keys node labels
    from 1.

    Parameters
    ----------
    G : grakel.Graph
        Graph whose attributes are overwritten in place.
    A : numpy.ndarray
        Dense square adjacency matrix to read edges and values from.
    """
    n = A.shape[0]
    nodes = list(range(n))
    G.n = n
    G.vertices = nodes
    G.node_labels = dict.fromkeys(nodes, '0')
    G.index_node_labels = G.node_labels

    edge_labels = {}
    edge_dict = {i: {} for i in nodes}
    # np.nonzero returns indices in row-major order, so keys are inserted
    # with i ascending, then j ascending. tolist() yields plain Python ints.
    rows, cols = np.nonzero(A)
    for i, j in zip(rows.tolist(), cols.tolist()):
        w = A[i, j]
        edge_labels[(i, j)] = w
        edge_dict[i][j] = w
    G.edge_labels = edge_labels
    G.index_edge_labels = edge_labels
    G.edge_dictionary = edge_dict


def init_weights_labels(graphs, *, rng=None, symmetric=False):
    """
    Give every graph random edge weights and trivial node/edge labels.

    For each graph, every nonzero adjacency entry is replaced by a uniform
    random value in [0, 1). The weighted matrix is stored in
    ``G.adjacency_matrix``, and labels plus ``G.edge_dictionary`` are rebuilt
    from it. Graphs are modified in place.

    Parameters
    ----------
    graphs : iterable of grakel.Graph
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness, for reproducible weights. ``None`` (default)
        draws from the global ``np.random`` state.
    symmetric : bool, optional
        If True, entries ``(i, j)`` and ``(j, i)`` receive the same weight, so
        undirected graphs keep a symmetric adjacency matrix. The default
        (False) draws every entry independently.
    """
    draw = np.random.random if rng is None else rng.random
    for G in graphs:
        # astype copies, so the graph's original matrix is not mutated.
        A = G.get_adjacency_matrix().astype(np.float64)
        weights = draw(size=A.shape)
        if symmetric:
            # Mirror the strict upper triangle; keep the diagonal (self-loops).
            upper = np.triu(weights, 1)
            weights = upper + upper.T + np.diag(np.diag(weights))
        nonzero_mask = A != 0.0
        A[nonzero_mask] = weights[nonzero_mask]
        G.adjacency_matrix = A
        _assign_trivial_labels(G, A)


def init_labels(graphs):
    """
    Give every graph trivial node labels and value-labelled edges.

    Unlike init_weights_labels, no weights are changed. Edge labels and
    ``G.edge_dictionary`` are built from the matrix returned by
    ``G.get_adjacency_matrix()``, cast to float64. Graphs are modified in
    place.
    """
    for G in graphs:
        # Use the returned matrix rather than G.adjacency_matrix: this
        # function never sets that attribute itself.
        # NOTE(review): whether grakel always populates it depends on the
        # graph's internal format — TODO confirm.
        A = G.get_adjacency_matrix().astype(np.float64)
        _assign_trivial_labels(G, A)

def add_labels(graphs):
    """
    Fill in default labels on graphs that have none.

    For each graph, ``G.n`` is set to the number of vertices, then:

    * if ``G.node_labels`` is empty or None, every node gets label ``0``,
      keyed ``1..n``;
    * if ``G.edge_labels`` is empty or None, every nonzero adjacency entry
      ``(i, j)`` (0-based indices) gets label ``0``.

    Existing non-empty labels are left untouched. Graphs are modified in
    place.
    """
    for G in graphs:
        # Read edges from the returned matrix instead of G.adjacency_matrix,
        # which this function never sets.
        A = G.get_adjacency_matrix()
        G.n = A.shape[0]
        if not G.node_labels:
            # Node keys start from 1 because some kernels expect that.
            # NOTE(review): the edge-label keys below are 0-based; confirm
            # this mix is what the kernels in use require.
            G.node_labels = dict.fromkeys(range(1, G.n + 1), 0)
        if not G.edge_labels:
            # np.nonzero gives row-major order; tolist() yields Python ints.
            rows, cols = np.nonzero(A)
            G.edge_labels = {(i, j): 0 for i, j in zip(rows.tolist(), cols.tolist())}

def show_stats(graphs):
    """
    Collect simple per-graph size statistics for a dataset.

    Each graph's adjacency matrix is converted with ``nx.from_numpy_array``
    (an undirected networkx graph). Nodes, edges and connected components are
    then counted on that graph.

    Returns
    -------
    dict
        ``'num_graphs'`` is the number of graphs. ``'num_nodes'``,
        ``'num_edges'``, ``'num_components'`` and ``'genus'`` are lists with
        one entry per graph, in input order.

    Note: ``'genus'`` is ``E - V + C``, i.e. the cyclomatic number (circuit
    rank, the number of independent cycles), not the topological genus. The
    key name is kept for compatibility with existing callers.
    """
    total = len(graphs)
    columns = {
        'num_nodes': [],
        'num_edges': [],
        'num_components': [],
        'genus': [],
    }

    for grakel_graph in graphs:
        nx_graph = nx.from_numpy_array(grakel_graph.get_adjacency_matrix())
        v = nx_graph.number_of_nodes()
        e = nx_graph.number_of_edges()
        c = nx.number_connected_components(nx_graph)
        for key, value in (('num_nodes', v), ('num_edges', e),
                           ('num_components', c), ('genus', e - v + c)):
            columns[key].append(value)

    return {'num_graphs': total, **columns}