import os, sys, wandb, numpy as np, torch
from pathlib import Path

import metrics

# Resolve this file's absolute location so the project root can be put on
# sys.path; this must run before the `data.school_networks...` import below.
file = Path(__file__).resolve()
# parents[2] is the directory two levels above the one containing this file.
# Kept with a trailing '/' (presumably so sub-paths can be appended by plain
# string concatenation).
path2project = str(file.parents[2]) + '/'
# Current working directory, also with a trailing '/'. Not used in the code
# visible here — NOTE(review): confirm it is used elsewhere or drop it.
path2currDir = str(Path.cwd()) + '/'
# append (not insert at 0): entries already on sys.path keep precedence.
sys.path.append(path2project) # add top level directory -> dpg/

from data.school_networks.school_networks import SchoolNetworks
from metrics import compute_metrics



def network_deconvolution(x, alpha=1.0):
    """Apply (scaled) network deconvolution to a batch of symmetric matrices.

    Each matrix ``X`` is eigendecomposed as ``X = V diag(lam) V^T``.  Every
    eigenvalue is mapped to ``lam / (1 + alpha * lam)``, and the function
    returns ``V diag(lam / (1 + alpha * lam)) V^T``.  For symmetric ``X`` with
    ``I + alpha * X`` invertible, this equals ``X @ inv(I + alpha * X)``.

    Args:
        x: Tensor of shape ``(*, N, N)``.  Leading dimensions are batch
            dimensions; the original ``(batch, N, N)`` layout still works.
            Matrices are assumed symmetric, since ``torch.linalg.eigh``
            treats its input as symmetric/Hermitian.
        alpha: Scale applied to the eigenvalues in the denominator.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``x`` has fewer than two dimensions or its last two
            dimensions differ.

    Note:
        No eigenvalue normalisation is enforced here.  An eigenvalue equal to
        ``-1 / alpha`` makes the mapping divide by zero and yields inf/nan.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if x.ndim < 2 or x.shape[-1] != x.shape[-2]:
        raise ValueError(
            "network_deconvolution expects square matrices of shape "
            f"(*, N, N), got {tuple(x.shape)}"
        )

    vals, vecs = torch.linalg.eigh(x)
    nd_vals = vals / (1 + alpha * vals)  # network-deconvolution eigenvalue map
    # V @ diag(d) only scales column j of V by d[j].  Broadcasting does that
    # directly, without a full N x N x N matmul against a diagonal matrix.
    return (vecs * nd_vals.unsqueeze(-2)) @ vecs.transpose(-2, -1)


if __name__ == "__main__":
    # `pass` guarantees the suite has at least one statement; a suite made
    # only of comments is an IndentationError and breaks importing the file.
    # Re-enable `train()` once it is defined or imported in this module.
    # train()
    pass
