#!/opt/conda/bin/python3
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import from_networkx, to_dense_adj
import torch_geometric.transforms as T

import matplotlib.pyplot as plt

from sct_gnn import SCT, kernel_vectors
from sct_gnn.smoothness import NodeFeatureSmoothness

# Toy-graph settings. They are only referenced by the commented-out
# Erdos-Renyi alternative below.
graph_dim, p = 4, 0.2


# Alternative input (currently disabled): a small random graph instead of Cora.
# nx_data = nx.erdos_renyi_graph(int(graph_dim), float(p))
# nx.draw(nx_data, with_labels=True)
# plt.savefig('/root/workspace/out/tmp.pdf')
# data = from_networkx(nx_data)
data_name = 'cora'
print(f'Planetoid: {data_name}')

# Preprocessing pipeline: drop isolated nodes, then apply PyG's GCN
# normalization. The code below reads the resulting weights from
# data.edge_weight.
transform = T.Compose([
    T.RemoveIsolatedNodes(),
    T.GCNNorm(),
])
dataset = Planetoid(
    root=f'/root/workspace/data/{data_name}',
    name=data_name,
    split='public',
    transform=transform,
)
data = dataset[0]  # Planetoid datasets hold a single graph


# Reference computation: build the dense (normalized) adjacency matrix and
# collect its eigenvectors with eigenvalue ~1, i.e. the kernel of (adj - I).
adj = to_dense_adj(data.edge_index, edge_attr=data.edge_weight)[0]
print(adj)
ew, ev = torch.linalg.eig(adj)
# eig always returns complex tensors; keep only the real parts.
# NOTE(review): if adj is guaranteed symmetric, torch.linalg.eigh would be
# cheaper and return real output directly. Confirm before switching.
ew, ev = ew.real, ev.real
# Eigenvectors are the *columns* of ev. Transpose so each row is one vector,
# then keep the rows whose eigenvalue is close to 1.
is_one = torch.isclose(ew, torch.ones(1))
ker_vecs = ev.T[is_one]
print(len(ker_vecs))
# Normalize each kernel vector (row-wise, L2). The previous per-row loop only
# rebound its loop variable, so the normalized vectors were discarded.
# F.normalize clamps the norm away from zero and handles an empty selection.
ker_vecs = torch.nn.functional.normalize(ker_vecs, p=2.0, dim=1)
# To inspect accuracy: (adj @ ker_vecs.T - ker_vecs.T).abs().amax(dim=0)
# gives the max absolute residual per kernel vector.

# Validate the project's kernel_vectors(): each returned vector should
# satisfy adj @ vec == vec. Report the max absolute residual of every vector
# that misses this by more than 1e-7.
# NOTE(review): if adj is float32, 1e-7 is below machine epsilon (~1.19e-7),
# so some reports may be rounding noise rather than real errors.
print('my')
ker_vecs = kernel_vectors(data.edge_index, edge_weight=data.edge_weight)
print(len(ker_vecs))
for vec in ker_vecs:
    residual = (adj @ vec - vec).abs().max()
    if not residual <= 1e-7 and residual > 1e-7:
        print(residual)

# Stop the script here; the smoothness experiment below is currently disabled.
# Use sys.exit() instead of the site-module helper exit(): exit() is meant for
# the interactive shell and does not exist when Python runs with -S.
import sys

sys.exit()

# Compare the node-feature smoothness of a plain GCN layer's output with the
# output after adding the SCT correction term. (Unreachable while the early
# exit above is in place.)
#
# Random 2-d input features, one row per node. These were previously
# hard-coded to 4 rows, which only fits the 4-node toy graph (graph_dim) and
# fails on Cora, whose edge_index references far more nodes.
x = torch.rand(data.num_nodes, 2)
# Edge weights were already normalized by the GCNNorm dataset transform, so
# GCNConv's own on-the-fly normalization is turned off.
mod = GCNConv(2, 1, normalize=False)
bias = SCT(2, 1, cached=True)

z = mod(x, data.edge_index, edge_weight=data.edge_weight)
y = bias(x, data.edge_index, edge_weight=data.edge_weight)

measure = NodeFeatureSmoothness()

print(measure(z, data.edge_index, edge_weight=data.edge_weight))
print(measure(z + y, data.edge_index, edge_weight=data.edge_weight))
