import networkx as nx

from torch_geometric.data import Data
from torch_geometric.datasets import CitationFull
from torch_geometric.utils import is_undirected, subgraph, to_networkx

from arrow_diff.metrics import get_graph_statistics


def _average_degree(graph):
    """Return the mean node degree of ``graph``.

    Iterates over the degree view instead of indexing ``graph.degree[i]`` for
    ``i in range(n)``, so the result does not rely on nodes being labelled
    ``0..n-1`` and no intermediate list is built.
    """
    return sum(degree for _, degree in graph.degree()) / graph.number_of_nodes()


# For every CitationFull dataset: print statistics of the full graph, extract its largest connected component,
# and print the statistics of that component.
for dataset_name in ['Cora_ML', 'Cora', 'CiteSeer', 'DBLP', 'PubMed']:
    print(f'Dataset: {dataset_name}')
    dataset = CitationFull(root='./data/', name=dataset_name)

    data = dataset[0]

    # Determine directedness up front: it decides how edges are counted and which component routine is valid.
    undirected = is_undirected(data.edge_index, num_nodes=data.num_nodes)

    # An undirected graph stores each edge in both directions in ``edge_index``, so halve the count only then.
    edge_divisor = 2 if undirected else 1

    print('\t Original Graph:', data)
    print(f'\t- Number of nodes: {data.num_nodes}')
    print(f'\t- Number of edges: {data.num_edges // edge_divisor}')
    print(f'\t- Number of node features: {data.num_node_features}')
    print(f'\t- Number of node classes: {dataset.num_classes}')

    print(f'\t- Undirected graph: {undirected}')

    graph = to_networkx(data, to_undirected=undirected)

    print('\t', graph)

    # ``nx.connected_components`` is not implemented for directed graphs; fall back to weakly connected
    # components there so the script does not raise on a directed dataset.
    find_components = nx.connected_components if undirected else nx.weakly_connected_components

    # Materialise the components once and reuse them for both the largest component and the component count.
    connected_components = list(find_components(graph))

    largest_connected_component = list(max(connected_components, key=len))

    print(f'\t- Number of nodes in largest connected component: {len(largest_connected_component)}')
    print(f'\t- Number of connected components: {len(connected_components)}')
    print(f'\t- Average node degree: {_average_degree(graph):.1f}')

    # Keep only the edges between nodes of the largest component and relabel the nodes consecutively.
    edge_index, _ = subgraph(largest_connected_component, data.edge_index, relabel_nodes=True, num_nodes=data.num_nodes)

    # Only the topology is carried over; node features and labels are not part of the reduced graph.
    data = Data(edge_index=edge_index, num_nodes=len(largest_connected_component))

    print('\n\t Graph with largest connected component:')
    print(f'\t- Number of nodes: {data.num_nodes}')
    print(f'\t- Number of edges: {data.num_edges // edge_divisor}')

    graph = to_networkx(data, to_undirected=undirected)

    print(f'\t- Average node degree: {_average_degree(graph):.1f}')

    # NOTE(review): whether ``get_graph_statistics`` supports directed graphs is not visible from this file;
    # confirm before running on a directed dataset.
    (degree_assortativity, avg_clustering_coeff, global_clustering_coeff, power_law_exp, num_triangles, max_degree,
     connected_component_sizes) = get_graph_statistics(graph)

    print(f'\t- Degree assortativity: {degree_assortativity}')
    print(f'\t- Average clustering coefficient: {avg_clustering_coeff}')
    print(f'\t- Global clustering coefficient: {global_clustering_coeff}')
    print(f'\t- Power law exponent: {power_law_exp}')
    print(f'\t- Number of triangles: {num_triangles}')
    print(f'\t- Maximum node degree: {max_degree}')
    print(f'\t- Number of connected components: {len(connected_component_sizes)}')

    print('\n')
