import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F

def vn_entropy(k, eps=1e-20):
    """Return the von Neumann entropy of the trace-normalised matrix ``k``.

    ``k`` is divided by its trace so its eigenvalues sum to one, then
    ``-sum(lam * log(lam + eps))`` is taken over the strictly positive
    eigenvalues ``lam``.

    Args:
        k: Square symmetric tensor. Only its lower triangle is read by
            ``torch.linalg.eigvalsh``. Its trace must be non-zero, since
            ``k`` is divided by it.
        eps: Small constant added inside the logarithm for numerical safety.

    Returns:
        0-dim tensor holding the entropy.
    """
    k = k / torch.trace(k)
    # ``torch.symeig`` has been removed from PyTorch. ``eigvalsh`` is its
    # replacement and, like symeig's default, reads the lower triangle.
    # abs() presumably absorbs tiny negative eigenvalues caused by round-off
    # on positive semi-definite inputs.
    eigv = torch.abs(torch.linalg.eigvalsh(k))
    positive = eigv[eigv > 0]
    return -torch.sum(positive * torch.log(positive + eps))

def entropy_loss(sigma, rho, beta):
    """Entropy-based objective comparing ``sigma`` against ``rho``.

    In ``sparse`` below, ``sigma`` is the Laplacian of a sampled subgraph and
    ``rho`` the Laplacian of the full graph.

    Args:
        sigma: Square tensor passed to ``vn_entropy``.
        rho: Tensor with the same shape as ``sigma``.
        beta: Non-negative weight. With ``beta == 0`` the loss is
            ``vn_entropy(sigma)``. With ``beta > 0`` it is
            ``0.5 * (1 - beta) / beta * vn_entropy(sigma)
            + vn_entropy(0.5 * (sigma + rho))``.

    Returns:
        0-dim tensor with the loss value.

    Raises:
        ValueError: if ``beta`` is negative or NaN.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``. ``not beta >= 0`` also rejects NaN.
    if not beta >= 0:
        raise ValueError("beta shall be >=0")
    if beta == 0:
        return vn_entropy(sigma)
    return 0.5 * (1 - beta) / beta * vn_entropy(sigma) + vn_entropy(0.5 * (sigma + rho))

def sparse(G, tau, n_samples, max_iteration, lr, beta):
    '''Learn a sparsified subgraph of ``G`` by minimising ``entropy_loss``.

    Each edge gets two logits. A hard Gumbel-softmax sample over them yields
    a 0/1 mask ``w`` (column 1 of the one-hot sample). The logits are trained
    with Adam to minimise the Monte-Carlo average over ``n_samples`` draws of
    ``entropy_loss(sigma, rho, beta)``. Here ``rho = E @ E.T`` is the
    Laplacian of ``G`` and ``sigma`` is the Laplacian of the masked subgraph.

    Args:
        G: networkx Graph. Must have at least one edge.
        tau: Gumbel-softmax temperature.
        n_samples: number of samples for gumbel softmax per iteration (>= 1).
        max_iteration: number of optimisation steps.
        lr: Adam learning rate.
        beta: weight forwarded to ``entropy_loss`` (must be >= 0).

    Returns:
        (sigma, w): the Laplacian of one final sampled subgraph (n x n) and
        its edge mask (length m, ordered like the columns of
        ``nx.incidence_matrix(G)``). Both are detached from autograd.

    Raises:
        ValueError: if ``G`` has no edges or ``n_samples < 1``.
    '''
    m = G.number_of_edges()
    if m == 0:
        raise ValueError("G must have at least one edge")
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")

    # Oriented incidence matrix (n x m). ``toarray`` gives a plain ndarray
    # whether networkx returns a scipy sparse matrix or a sparse array.
    E = torch.from_numpy(
        nx.incidence_matrix(G, oriented=True).toarray().astype(np.double)
    )
    rho = E @ E.T

    # Logits share E's dtype (float64). A float32 mask would make the matmul
    # with E fail on a dtype mismatch.
    theta = torch.randn(m, 2, dtype=E.dtype, requires_grad=True)
    optimizer = torch.optim.Adam([theta], lr=lr)

    def _sample_laplacian():
        # Draw one hard edge mask and return (masked Laplacian, mask).
        # hard=True gives one-hot samples; per the PyTorch docs, gradients
        # flow through the soft relaxation. No squeeze(): w must stay 1-D
        # even when m == 1.
        w = F.gumbel_softmax(theta, tau, hard=True)[:, 1]
        # (E * w) scales column j of E by w[j], which equals E @ diag(w)
        # but avoids building an m x m matrix.
        return (E * w) @ E.T, w

    for _ in range(max_iteration):
        cost = 0
        for _ in range(n_samples):
            sigma, _ = _sample_laplacian()
            cost = cost + entropy_loss(sigma, rho, beta)
        cost = cost / n_samples

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

    # The final draw is not optimised, so don't keep an autograd graph.
    with torch.no_grad():
        sigma, w = _sample_laplacian()  # sparse laplacian

    return sigma, w








