import pickle
import numpy as np
from scipy.sparse import load_npz
import torch
from deeprobust.graph.data import Dataset
import copy
from my_utils.utils import spectral_embedding, hnsw, construct_adj,adj2laplacian,spectral_embedding_eig, SPF
from scipy.sparse.csgraph import laplacian
from scipy.sparse.linalg import eigsh
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from julia.api import Julia
from scipy.io import mmwrite

def julia_eig(l_in, num_eigs):
    """Compute eigenvalues/eigenvectors of ``l_in`` via the Julia helper script.

    The work is delegated to the ``not_main`` function defined in
    ``./my_utils/eigen.jl`` (path is relative to the current working
    directory).

    The Julia runtime handle and the ``Main`` module are created on the first
    call and cached on the function object, so repeated calls (e.g. once per
    dataset) do not re-initialise Julia or re-include the script.

    Args:
        l_in: Matrix handed to Julia unchanged (the caller in this file passes
            a normalized graph Laplacian).
        num_eigs: Number of eigenpairs requested from ``not_main``.
            NOTE(review): presumably the ``num_eigs`` smallest ones -- confirm
            against eigen.jl.

    Returns:
        Tuple ``(eigenvalues, eigenvectors)`` exactly as returned by
        ``not_main``.
    """
    main = getattr(julia_eig, "_main", None)
    if main is None:
        # Keep a reference to the runtime handle for the life of the process.
        julia_eig._runtime = Julia(compiled_modules=False)
        # `julia.Main` can only be imported once the runtime is initialised.
        from julia import Main as main
        main.include("./my_utils/eigen.jl")
        julia_eig._main = main
    eigenvalues, eigenvectors = main.not_main(l_in, num_eigs)
    return eigenvalues, eigenvectors

# Torch device handle; everything here is configured for the CPU.
device = torch.device('cpu')

# For each benchmark graph: load it, rebuild a graph from its spectral
# embedding, and plot the spectrum of that graph's normalized Laplacian.
for dataset in ['cora', 'citeseer', 'chameleon', 'squirrel','pubmed']:
    # --- Load features, labels, splits and adjacency ------------------------
    if dataset not in ['chameleon', 'squirrel']:
        # These datasets come from DeepRobust using the 'prognn' setting.
        data = Dataset(root='./data/', name=dataset, setting='prognn')
        labels = data.labels
        features = data.features.todense()
        idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
        adj_mtx = data.adj
    else:
        # These two are stored locally as a pickle (features/labels/splits)
        # plus a separate .npz adjacency matrix.
        # NOTE: pickle.load must only be used on trusted files.
        with open(f'data/{dataset}_data.pickle', 'rb') as handle:
            data = pickle.load(handle)
        features, labels = data["features"], data["labels"]
        idx_train, idx_val, idx_test = (
            data[key] for key in ("idx_train", "idx_val", "idx_test")
        )
        adj_mtx = load_npz(f'data/{dataset}.npz')

    # --- Rebuild the graph from a spectral embedding ------------------------
    # Make sure the adjacency matrix has a floating-point dtype.
    adj_mtx = adj_mtx.asfptype()
    spec_embed = spectral_embedding_eig(
        adj_mtx,
        features,
        use_feature=True,
        adj_norm=False,
        eig_julia=True,
    )
    # presumably a 50-nearest-neighbour search in embedding space -- see my_utils
    neighs, distance = hnsw(spec_embed, k=50)
    adj_mtx,_,_ = construct_adj(neighs, distance)
    # SPF is a project helper; the meaning of the argument 4 is not visible here.
    adj_mtx = SPF(adj_mtx, 4)
    L_mtx = laplacian(adj_mtx, normed=True)#.tocsr()

    # --- Spectrum ------------------------------------------------------------
    # SciPy-based alternative to the Julia solver, kept for reference:
    #eigenvalues, eigenvectors = eigsh(L_mtx,k=40,which='SM', maxiter=500000)
    #eigenvalues = np.sort(eigenvalues)
    eigenvalues, eigenvectors = julia_eig(L_mtx, 100)

    # --- Plot eigenvalue index (1-based) against the real part ---------------
    x = np.arange(1, 1 + len(eigenvalues))
    ax = plt.gca()
    ax.plot(x, eigenvalues.real, marker='o')
    ax.set_xlabel('nth Smallest Eigenvalue')
    ax.set_ylabel('Eigenvalue')
    ax.set_title('{},class:{}'.format(dataset,labels.max()+1))
    ax.grid(True)
    # Restrict x-axis ticks to integers, since x is an eigenvalue index.
    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    plt.savefig('{}.png'.format(dataset), dpi=300)
    # Clear the current figure so the next dataset starts from a blank canvas.
    plt.clf()
