import torch
from utils.graph_utils import mask_adjs, mask_x


def prior_none():
    """Build a prior drift that contributes nothing.

    Returns:
        A callable ``prior_drift(x, adj, flags, t)`` that ignores all four
        arguments and always returns the float ``0.0``.
    """
    zero_drift = 0.0

    def prior_drift(x, adj, flags, t):
        # No prior is applied: the drift is the same constant for any input.
        return zero_drift

    return prior_drift


def prior_eig(alpha):
    """Build a prior drift driven by one eigenvalue of the normalized Laplacian.

    Args:
        alpha: Factor the masked drift is multiplied by.

    Returns:
        A callable ``prior_drift(x, adj, flags, t)``. It forms
        ``I - D^{-1/2} A D^{-1/2}`` from ``adj`` (degrees floored at 1),
        takes the eigenvalue at index 1 of every matrix in the batch
        (``torch.linalg.eigh`` returns eigenvalues in ascending order), sums
        them, and returns ``mask_adjs(-gradient, flags) * alpha`` where the
        gradient is taken with respect to ``adj``. ``x`` and ``t`` are unused.
    """
    def prior_drift(x, adj, flags, t):
        # Autograd is switched on locally so this also works when the caller
        # runs under torch.no_grad().
        with torch.enable_grad():
            adj_leaf = torch.nn.Parameter(adj)
            batch_size, num_nodes = adj.shape[0], adj.shape[-1]

            # Flooring the degree at 1 keeps zero-degree rows from producing
            # an infinite inverse square root.
            inv_sqrt_deg = adj_leaf.sum(dim=-1).clamp(min=1.).pow(-0.5)
            scaled_adj = inv_sqrt_deg.unsqueeze(-1) * adj_leaf * inv_sqrt_deg.unsqueeze(-2)
            identity = torch.eye(num_nodes, device=adj.device).repeat(batch_size, 1, 1)
            laplacian = identity - scaled_adj

            eigenvalues, _ = torch.linalg.eigh(laplacian)

            # Lyapunov function: eigenvalue at index 1, summed over the batch.
            objective = eigenvalues[:, 1].sum()
            objective.backward()

            # Move against the gradient; masking is delegated to mask_adjs.
            masked_descent = mask_adjs(-adj_leaf.grad, flags)
            drift = masked_descent * alpha
        return drift

    return prior_drift


def prior_spectrum(alpha, spectrum, num_eig=5):
    """Build a prior drift that pulls the Laplacian spectrum toward a target.

    Args:
        alpha: Factor the masked drift is multiplied by.
        spectrum: Tensor of target eigenvalues. Entry 0 is skipped; entries
            ``1 .. num_eig`` are used (all entries from 1 onward when
            ``num_eig <= 0``).
        num_eig: Number of target eigenvalues to match; values ``<= 0`` mean
            "use every entry of ``spectrum`` after the first".

    Returns:
        A callable ``prior_drift(x, adj, flags, t)``. It forms
        ``I - D^{-1/2} A D^{-1/2}`` from ``adj`` (degrees floored at 1),
        averages the eigenvalues at indices ``1 .. k`` over the batch, and
        returns ``mask_adjs(-gradient, flags) * alpha`` where the gradient is
        that of the summed squared difference to the target, taken with
        respect to ``adj``. ``x`` and ``t`` are unused.

        ``k`` is the number of requested targets, capped at the number of
        eigenvalues available after index 0. Without the cap, a graph with
        too few nodes either raised a shape error or, with a single
        eigenvalue left, silently broadcast it against every target.
    """
    def prior_drift(x, adj, flags, t):
        # Autograd is switched on locally so this also works when the caller
        # runs under torch.no_grad().
        with torch.enable_grad():
            adj_para = torch.nn.Parameter(adj)
            # Flooring the degree at 1 keeps zero-degree rows from producing
            # an infinite inverse square root.
            deg_inv_sqrt = adj_para.sum(dim=-1).clamp(min=1.).pow(-0.5)
            lap = torch.eye(adj.shape[-1], device=adj.device).repeat(adj.shape[0], 1, 1) - \
                    deg_inv_sqrt.unsqueeze(-1) * adj_para * deg_inv_sqrt.unsqueeze(-2)
            # eigh returns eigenvalues in ascending order; eigenvectors unused.
            eigvals, _ = torch.linalg.eigh(lap)

            spec = spectrum[1:num_eig + 1] if num_eig > 0 else spectrum[1:]
            spec = spec.to(adj.device)

            # Compare only as many eigenvalues as both sides actually have.
            num_cmp = min(spec.shape[-1], eigvals.shape[-1] - 1)
            spec = spec[..., :num_cmp]
            batch_mean = eigvals[:, 1:num_cmp + 1].mean(dim=0)

            lyapunov_func = torch.square(batch_mean - spec).sum()
            lyapunov_func.backward()

            # Move against the gradient; masking is delegated to mask_adjs.
            drift = -adj_para.grad
            drift = mask_adjs(drift, flags) * alpha
        return drift
    return prior_drift


def prior_deg(alpha):
    """Build a prior drift from a cross-entropy between ``x`` and node degrees.

    Args:
        alpha: Factor the masked drift is multiplied by.

    Returns:
        A callable ``prior_drift(x, adj, flags, t)`` returning
        ``mask_x(-gradient, flags) * alpha``, where the gradient is that of
        ``CrossEntropyLoss`` between ``x`` and ``adj.sum(-1)``, taken with
        respect to ``x``. ``t`` is unused.

    Note:
        ``CrossEntropyLoss`` expects the class axis at dim 1 and either
        integer class indices or same-shape class probabilities as target.
        When ``x`` has exactly one more trailing axis than ``adj.sum(-1)``
        (e.g. ``(batch, nodes, classes)`` vs ``(batch, nodes)``), the trailing
        axis is treated as the class axis and the degree is converted to a
        class index; previously that layout raised or mis-read the node axis
        as classes. Any other layout is passed through unchanged.
        NOTE(review): assumes ``x`` holds per-node logits over degree classes
        with class ``d`` meaning degree ``d`` -- confirm against the feature
        initialisation used by the caller.
    """
    def prior_drift(x, adj, flags, t):
        criterion = torch.nn.CrossEntropyLoss()
        # Autograd is switched on locally so this also works when the caller
        # runs under torch.no_grad().
        with torch.enable_grad():
            x_para = torch.nn.Parameter(x)
            degree = adj.sum(-1)

            if x_para.dim() == degree.dim() + 1 and x_para.shape[:-1] == degree.shape:
                num_classes = x_para.shape[-1]
                # Flatten to (M, C) logits so the class axis sits at dim 1.
                logits = x_para.reshape(-1, num_classes)
                # Degrees of a continuous adjacency are not integers; round
                # them and saturate into the valid class range so the loss
                # never sees an out-of-bounds index.
                if degree.is_floating_point():
                    degree = degree.round()
                target = degree.long().clamp(min=0, max=num_classes - 1).reshape(-1)
                lyapunov_func = criterion(logits, target)
            else:
                # Layout already matches what CrossEntropyLoss accepts.
                lyapunov_func = criterion(x_para, degree).sum()
            lyapunov_func.backward()

            # Move against the gradient; masking is delegated to mask_x.
            drift = -x_para.grad
            drift = mask_x(drift, flags) * alpha

        return drift
    return prior_drift