#!/opt/conda/bin/python3
import time

import numpy as np
import torch

import networkx as nx
from torch_geometric.transforms import GCNNorm
from torch_geometric.utils import from_networkx, to_dense_adj

from sct_gnn import connected_components, kernel_vectors

import matplotlib.pyplot as plt

#----------------------------------------------------------------------------------------------------------------------------------------------------

def eigen_decomposition(edge_index, edge_weight):
    """Return the eigenvectors of the dense weighted adjacency with eigenvalue ~1.

    The graph is densified with ``to_dense_adj`` and decomposed with
    ``torch.linalg.eigh``; eigenvectors whose eigenvalue is within
    ``atol=1e-3`` / ``rtol=1e-3`` of 1 are kept.

    Args:
        edge_index: Edge index forwarded to ``to_dense_adj``
            (presumably shape ``(2, E)`` -- confirm against the caller).
        edge_weight: Per-edge values used as the entries of the dense matrix.

    Returns:
        A tensor of shape ``(k, N)`` holding one selected eigenvector per row,
        where ``k`` is the number of eigenvalues close to 1.

    Raises:
        AssertionError: If the dense matrix is not exactly symmetric.
    """
    aug = to_dense_adj(edge_index, edge_attr=edge_weight)[0]
    assert torch.all(torch.eq(aug, aug.T)), 'Augmented adjacency is not symmetric'

    eigs, vecs = torch.linalg.eigh(aug)
    # Build the comparison target from eigs itself so dtype and device always
    # match, instead of assuming a float32 CPU input.
    perp = torch.isclose(eigs, torch.ones_like(eigs), atol=1e-3, rtol=1e-3)
    # eigh stores eigenvectors as the COLUMNS of vecs; mask the transpose so
    # each returned row is an eigenvector rather than a row of coordinates.
    return vecs.T[perp] #jb: should be real only

#----------------------------------------------------------------------------------------------------------------------------------------------------

# Benchmark: for every (graph size, edge probability) pair, build one random
# Erdos-Renyi graph, apply GCNNorm, and time both ways of obtaining the
# kernel vectors. Each entry of dfs_data / eig_data is
# [num_nodes + number of edge_index columns, mean runtime in seconds].
dfs_data = []
eig_data = []
for graph_dim in [100, 600, 700, 800, 900, 1000]:
    print(f'Graph Dim: {graph_dim}')
    for p in [1e-4, 1e-3, 1e-2, 5e-2, 1e-1, 5e-1]:
        print(f'Edge Probability: {p}')

        nx_data = nx.erdos_renyi_graph(int(graph_dim), float(p))
        data = from_networkx(nx_data)
        data = GCNNorm()(data)

        dfs_times = []
        eig_times = []
        for _ in range(25):

            # time.perf_counter() is monotonic and high-resolution, unlike
            # time.time(), which can jump with wall-clock adjustments.

            # O(N+E)
            start = time.perf_counter()
            ind = connected_components(data.edge_index, data.num_nodes)
            kv = kernel_vectors(data.edge_index, edge_weight=data.edge_weight, indicators=ind)
            dfs_times.append(time.perf_counter()-start)

            # O(N^2 log(N))
            # NOTE(review): dense symmetric eigensolvers are usually cubic in N;
            # confirm the bound stated above.
            start = time.perf_counter()
            kv = eigen_decomposition(data.edge_index, data.edge_weight)
            eig_times.append(time.perf_counter()-start)

        # Shared x-coordinate for both series: nodes plus edge_index columns.
        graph_size = data.num_nodes + data.edge_index.shape[1]

        dfs_data.append([graph_size,
                np.mean(dfs_times)])

        eig_data.append([graph_size,
                np.mean(eig_times)])


# Plot mean runtime against graph size for both methods and save as a PDF.
plt.figure(tight_layout=True)

# Sort by graph size so each curve is drawn left to right.
dfs_data = sorted(dfs_data, key=lambda x: x[0])
size = [x for x,y in dfs_data]
times = [y for x,y in dfs_data]
plt.plot(size,times,label='connected_components + kernel_vectors')

eig_data = sorted(eig_data, key=lambda x: x[0])
size = [x for x,y in eig_data]
times = [y for x,y in eig_data]
plt.plot(size,times,label='eigen_decomposition')

plt.xscale('log')
# Label the axes and curves; without them the two lines are indistinguishable.
plt.xlabel('num_nodes + edge_index columns')
plt.ylabel('mean runtime (s)')
plt.legend()
plt.savefig('/root/workspace/out/tmp.pdf',format='pdf',bbox_inches='tight')