from warnings import warn

import torch as pt
from peagang.utils.utils import ensure_tensor
from DominantSparseEigenAD.symeig import DominantSymeig


def stable_sym_eigen(adj, eigenvectors=False, jitter=1.0):
    """
    Eigendecomposition of a symmetric matrix after a random diagonal perturbation.

    Adds ``jitter * randn(N)`` to the diagonal of ``adj`` before decomposing. The same
    noise vector is broadcast to every matrix in a batch. The intent is presumably to
    split repeated eigenvalues, whose eigenvector gradients are undefined.
    NOTE(review): the default ``jitter=1.0`` keeps the historical unit-variance noise,
    which is large relative to small-graph spectra; consider a much smaller value.

    Uses ``torch.linalg.eigh`` / ``torch.linalg.eigvalsh`` when available, because
    ``torch.symeig`` is deprecated and absent in recent PyTorch. Falls back to
    ``torch.symeig`` otherwise. Both paths read the upper triangle, matching
    ``torch.symeig``'s default.

    :param adj: symmetric tensor whose last two dims are the (N, N) matrix.
    :param eigenvectors: whether to also compute eigenvectors.
    :param jitter: scale of the diagonal noise; ``0`` disables the perturbation.
    :return: ``(eigenvalues, eigenvectors)`` with eigenvalues in ascending order.
        When ``eigenvectors`` is False, the second element is an empty tensor, as
        with ``torch.symeig``.
    """
    perturbed = adj + jitter * pt.diag(pt.randn(adj.shape[-1], device=adj.device))

    linalg = getattr(pt, "linalg", None)
    if linalg is not None and hasattr(linalg, "eigh"):
        if eigenvectors:
            return linalg.eigh(perturbed, UPLO="U")
        # eigvalsh skips the eigenvector computation entirely.
        vals = linalg.eigvalsh(perturbed, UPLO="U")
        return vals, vals.new_empty(0)

    # Old PyTorch without torch.linalg.eigh.
    return pt.symeig(perturbed, eigenvectors=eigenvectors)



def our_dom_symeig(x, k, end_offset=1):
    """
    Return the eigenpairs of ``x`` with the largest eigenvalues, minus the top ``end_offset``.

    Since the PyTorch builtin computes ALL eigenpairs, the full decomposition is taken
    and then sliced. It comes from ``stable_sym_eigen``, i.e. it is computed on a copy
    of ``x`` with random noise added to its diagonal.

    Eigenvalues are assumed to come back in ascending order, as ``torch.symeig`` and
    ``torch.linalg.eigh`` document. Positions ``N - k`` .. ``N - end_offset - 1`` are
    therefore the ``k - end_offset`` largest eigenvalues left after dropping the
    ``end_offset`` very largest ones.

    Intended use (from the original notes): with a negated graph Laplacian ``A - D``
    and ``end_offset=1``, this keeps the Laplacian's 2nd..kth eigenvectors. On a graph
    with 2 communities, the 1st vector is constant and the 2nd should have positive
    and negative entries.

    :param x: symmetric matrix, or batch of matrices in the last two dims; anything
        ``ensure_tensor`` accepts.
    :param k: how many of the largest eigenpairs to consider before dropping.
    :param end_offset: how many of the very largest eigenpairs to drop (default 1).
    :return: ``[vals, vecs]``. ``vals`` has shape ``(..., k - end_offset)``; ``vecs``
        holds the matching eigenvectors as columns, shape ``(..., N, k - end_offset)``.
    :raises ValueError: unless ``0 <= end_offset <= k <= N``.
    """
    x = ensure_tensor(x)
    n = x.shape[-1]
    # Validate before the expensive decomposition. With k > N the start index would be
    # negative and silently wrap around to the smallest eigenpairs.
    if not 0 <= end_offset <= k <= n:
        raise ValueError(
            f"need 0 <= end_offset <= k <= N, got end_offset={end_offset}, k={k}, N={n}"
        )

    vals, vecs = stable_sym_eigen(x, eigenvectors=True)

    with pt.no_grad():
        # Eigenvalues are sorted along the last axis, so repeated values are adjacent.
        # Comparing neighbours checks each batch element independently.
        # pt.unique(dim=-1) would instead compare whole columns across the batch and
        # miss repeats inside a single matrix.
        if (vals[..., 1:] == vals[..., :-1]).any().item():
            warn("Found non-unique eigen-values, this might cause NaNs in backwards")

    # Slicing the last axis keeps batch dims and picks eigenvector columns.
    selected = slice(n - k, n - end_offset)
    return [vals[..., selected], vecs[..., selected]]


def approx_dom_symeig(x, k):
    """
    Collect ``k`` results from ``DominantSymeig`` and stack them.

    ``DominantSymeig.apply`` appears to yield a single ``(value, vector)`` pair per
    call. It is therefore invoked once for each second argument ``1..k``, and the
    results are stacked along a new trailing axis.

    NOTE(review): confirm what ``DominantSymeig.apply``'s second argument means. If it
    is a Krylov/Lanczos dimension rather than an eigenpair index, every call
    approximates the same eigenpair, and the stack is not ``k`` distinct pairs.

    :param x: matrix passed unchanged to ``DominantSymeig.apply``.
    :param k: number of calls, i.e. number of stacked pairs.
    :return: ``(vals, vecs)``: per-call values and vectors stacked on axis -1.
    """
    pairs = [DominantSymeig.apply(x, order) for order in range(1, k + 1)]
    stacked_vals = pt.stack([val for val, _ in pairs], -1)
    stacked_vecs = pt.stack([vec for _, vec in pairs], -1)
    return stacked_vals, stacked_vecs

if __name__ == "__main__":
    # Smoke test: build negated graph Laplacians for the first few CommSmall graphs
    # and extract their leading spectral features.
    from peagang.data.dense.CommunitySmall import CommSmall

    ds = CommSmall()
    # Assumes each sample exposes a dense (N, N) adjacency as ``.A``, with the same N
    # for all samples so they can be stacked. TODO confirm.
    a = pt.stack([ds[i].A for i in range(3)])  # (B, N, N)
    # Degree vector per graph has shape (B, N). keepdim must stay False: with it,
    # diag_embed would produce (B, N, 1, 1) and ``a - D`` would mis-broadcast.
    deg = a.sum(-1)
    D = pt.diag_embed(deg)  # (B, N, N)
    print(D.shape)
    # A - D is the negated Laplacian. Its largest eigenvalues correspond to the
    # Laplacian's smallest ones, which is what our_dom_symeig selects.
    L = a - D
    _, feat = our_dom_symeig(L, 3, end_offset=1)
    print(feat.shape)
