import os
import torch
import numpy as np


def eig(sym_mat):
    """Eigendecompose a real symmetric matrix, ordered by |eigenvalue| ascending.

    Args:
        sym_mat: real symmetric (N, N) numpy array.

    Returns:
        eigvec: float32 tensor of shape [N, N]; column ``i`` is the eigenvector
            belonging to ``eigval[i]``.
        eigval: float32 tensor of shape [N]; absolute eigenvalues, ascending.
    """
    EigVal, EigVec = np.linalg.eigh(sym_mat)

    # Take abs because numpy sometimes computes the first eigenvalue approaching 0
    # from the negative. Sorting the absolute values can change their order, so
    # the same permutation is applied to the eigenvector columns to keep each
    # column paired with its eigenvalue. A stable sort leaves already-ordered
    # input (the usual case) untouched.
    abs_val = np.abs(np.real(EigVal))
    order = np.argsort(abs_val, kind="stable")

    eigvec = torch.from_numpy(EigVec[:, order]).float()  # [N, N (channels)]
    eigval = torch.from_numpy(abs_val[order]).float()  # [N (channels),]
    return eigvec, eigval


def preprocess_TokenGT(adj):
    """
    Graph positional encoding via Laplacian eigenvectors
    https://github.com/DevinKreuzer/SAN/blob/main/data/molecules.py

    Builds the symmetric-normalized Laplacian L = I - D^{-1/2} A D^{-1/2}
    (D = row sums of A, clipped to at least 1) and eigendecomposes it.

    Args:
        adj: (N, N) adjacency matrix. Either an object exposing ``.todense()``
            (e.g. a scipy sparse matrix) or anything ``np.asarray`` accepts.

    Returns:
        edge_index: integer tensor [2, E] with the (row, col) indices of the
            nonzero entries of ``adj``.
        eigvec: float32 tensor [N, N] of Laplacian eigenvectors (as columns).
        eigval: float32 tensor [N] of absolute Laplacian eigenvalues, ascending.
    """
    A = np.asarray(adj.todense()) if hasattr(adj, "todense") else np.asarray(adj)

    # Row sums; clip to 1 so isolated nodes do not cause a division by zero.
    degree = A.sum(axis=1)
    d_inv_sqrt = np.power(degree.clip(1), -0.5)

    # D^{-1/2} A D^{-1/2} via broadcasting: entry (i, j) is d_i * A_ij * d_j.
    # Same values as multiplying by a dense diagonal matrix twice, but O(N^2)
    # instead of O(N^3).
    L = np.eye(A.shape[0]) - d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]

    edge_index = torch.from_numpy(np.stack(A.nonzero()))

    eigvec, eigval = eig(L)
    return edge_index, eigvec, eigval  # [2, E]  [N, N (channels)]  [N (channels),]


def add_hops(adj, features, K=0):
    """Stack ``features`` propagated over 0..K hops of ``adj``.

    Args:
        adj: (N, N) torch tensor, dense or sparse, used as the propagation
            operator.
        features: (N, D) node features.
        K: number of propagation hops; ``K=0`` returns only the raw features.

    Returns:
        Tensor of shape [N, K+1, D] (torch's default dtype) whose slice
        ``[:, k, :]`` is ``adj`` applied ``k`` times to ``features``.
    """
    N = adj.shape[0]
    D = features.shape[-1]
    feat_hops = torch.zeros(N, K + 1, D)
    feat_hops[:, 0, :] = features

    # Choose the matmul once instead of re-checking sparsity every hop:
    # torch.spmm for a sparse adj, torch.mm for a dense one.
    matmul = torch.spmm if adj.is_sparse else torch.mm
    for k_idx in range(1, K + 1):
        feat_hops[:, k_idx, :] = matmul(adj, feat_hops[:, k_idx - 1, :])

    return feat_hops


def postprocess_fn(args, data, datapath='./data'):
    """Attach edge index, Laplacian PE and hop-propagated features to ``data``.

    The edge index and Laplacian eigenvectors are cached as ``edge_index.pt`` and
    ``lap_eigvec.pt`` under ``<datapath>/<args.dataset>/``; on a cache miss they
    are computed from ``data['adj_train']`` and written there.

    Args:
        args: object with ``dataset`` (str), ``hops`` (int, number of
            propagation hops) and ``lap_k`` (int, number of eigenvector columns
            kept from each end of the eigenvector matrix).
        data: dict with at least 'adj_train', 'adj_train_norm' and 'features'.
            Mutated in place.
        datapath: root directory of the cache.

    Returns:
        ``data`` with 'edge_index' set, 'features' replaced by a
        [N, hops+1, D] tensor and 'lap_eigvec' set to a
        [N, hops+1, 2*min(lap_k, N)] tensor.
    """
    cache_dir = os.path.join(datapath, args.dataset)
    edge_index_filename = os.path.join(cache_dir, 'edge_index.pt')
    lap_eigvec_filename = os.path.join(cache_dir, 'lap_eigvec.pt')

    if os.path.exists(edge_index_filename) and os.path.exists(lap_eigvec_filename):
        # NOTE(review): torch.load unpickles the file; only point datapath at
        # trusted caches.
        data['edge_index'] = torch.load(edge_index_filename, map_location=torch.device('cpu'))
        data['lap_eigvec'] = torch.load(lap_eigvec_filename, map_location=torch.device('cpu'))
    else:
        edge_index, eigvec, _eigval = preprocess_TokenGT(data['adj_train'])
        # torch.save does not create missing parent directories.
        os.makedirs(cache_dir, exist_ok=True)
        torch.save(eigvec, lap_eigvec_filename)
        torch.save(edge_index, edge_index_filename)
        data['lap_eigvec'] = eigvec
        data['edge_index'] = edge_index

    data['features'] = add_hops(
        data['adj_train_norm'],
        data['features'],
        K=args.hops,
    )

    # Keep the first lap_k and the last lap_k eigenvector columns. The tail start
    # is computed explicitly because `[:, -lap_k:]` would select *every* column
    # when lap_k == 0 (-0 == 0). Clamping at 0 keeps the lap_k > N case as before
    # (the whole matrix).
    eigvec = data['lap_eigvec']
    tail_start = max(eigvec.shape[1] - args.lap_k, 0)
    lap_pe = torch.cat((eigvec[:, :args.lap_k], eigvec[:, tail_start:]), dim=1)
    data['lap_eigvec'] = add_hops(
        data['adj_train_norm'],
        lap_pe,
        K=args.hops,
    )
    return data

