import numpy as np
from pathlib import Path
import torch 
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import get_laplacian
from libs.torch_utils import to_torch_coo_tensor

def diagonalize(
    edge_index: Adj,
    edge_weight: OptTensor = None,
    path: Path = None,
    cache: bool = True,
):
    """Eigendecompose the rescaled symmetric-normalized Laplacian of a graph.

    The Laplacian ``L`` is built by ``get_laplacian(..., normalization="sym")``
    and rescaled to ``L' = (2 / lam_max) * L - I``. Because ``L`` is positive
    semi-definite, the spectrum of ``L'`` lies in ``[-1, 1]``.

    Args:
        edge_index: graph connectivity.
        edge_weight: optional per-edge weights forwarded to ``get_laplacian``.
        path: directory holding the cached ``U.pt`` / ``lam.pt``. It is created
            if missing.
        cache: when True and ``path`` is given, load from the cache if both
            files exist, otherwise compute and write them.

    Returns:
        ``(lam, U)``: eigenvalues of ``L'`` in ascending order, and the
        matching eigenvectors as columns of ``U``.

    Note:
        The cache is keyed only by ``path``. It does not account for
        ``edge_index`` or ``edge_weight``, so use a distinct ``path`` per graph.
    """
    u_file = l_file = None
    if path is not None:
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        u_file = path / "U.pt"
        l_file = path / "lam.pt"

    if cache and path is not None and u_file.is_file() and l_file.is_file():
        print("Loaded from cache")
        U = torch.load(u_file)
        lam = torch.load(l_file)
        return lam, U

    eidx, ew = get_laplacian(edge_index, edge_weight=edge_weight, normalization="sym")
    L = to_torch_coo_tensor(eidx, ew)
    lam, U = torch.linalg.eigh(L.to_dense())
    # L' = (2 / lam_max) * L - I has the same eigenvectors as L, with
    # eigenvalues 2 * lam / lam_max - 1. The scale is positive, so the
    # ascending order is preserved and no second eigendecomposition is needed.
    lam = 2 / lam.max() * lam - 1

    if cache and path is not None:
        print("Dumping to file for later use")
        torch.save(U, u_file)
        torch.save(lam, l_file)

    return lam, U


def cutoff_sparsifier(matrix, cutoff):
    """Zero, in place, every entry of ``matrix`` whose magnitude is below ``cutoff``.

    The input tensor itself is modified, and that same object is returned.
    """
    too_small = matrix.abs() < cutoff
    matrix.masked_fill_(too_small, 0)
    return matrix


def topk_sparsifier(matrix, k):
    """Keep the ``k`` largest entries of each column of ``matrix``; zero the rest.

    Entries are ranked by signed value (not magnitude), independently per
    column (``dim=0``).

    Args:
        matrix: 2-D dense tensor. It is not modified.
        k: number of entries kept per column. It is clipped to
            ``[0, matrix.shape[0]]``.

    Returns:
        A new dense tensor with the same shape, dtype and device as ``matrix``.
    """
    print(f"TOPK SPARSIFIER WITH K={k}")
    dim = 0
    k = max(0, min(matrix.shape[dim], k))
    # torch.sort returns the same indices as torch.argsort, plus the sorted
    # values, so no separate gather is needed.
    values, indices = torch.sort(matrix, dim=dim, descending=True)
    top_values, top_idx = values[:k], indices[:k]
    # Scatter into a zero tensor of the input's shape. Building a COO tensor
    # and inferring its size from the kept indices can produce a non-square
    # result when the last row is never among any column's top-k.
    out = torch.zeros_like(matrix)
    out.scatter_(dim, top_idx, top_values)
    return out


def band_pass_filter(edge_index, e0, sigma, cutoff, eigen_path, eigen_cache, sparsifier='cutoff',
                     steepness=10.0):
    """Build a normalized, sparsified band-pass spectral operator of the graph.

    The response ``h(lam) = expit(s * (lam - (e0 - sigma))) * expit(s * ((e0 + sigma) - lam))``
    (``s = steepness``) is evaluated on the spectrum returned by ``diagonalize``.
    The operator ``U diag(lam * h(lam)) U^T`` is then formed, optionally
    sparsified, and scaled on both sides by the inverse square root of its
    absolute column sums.

    Args:
        edge_index: graph connectivity passed to ``diagonalize``. The result is
            moved to its device.
        e0: centre of the pass band.
        sigma: half-width of the pass band.
        cutoff: magnitude threshold for ``sparsifier='cutoff'``, or the number
            of entries kept per column for ``sparsifier='topk'``.
        eigen_path: directory for the eigendecomposition cache (may be ``None``).
        eigen_cache: whether to read/write that cache.
        sparsifier: ``'cutoff'`` or ``'topk'``. Any other value skips
            sparsification.
        steepness: slope of the two logistic band edges (default 10).

    Returns:
        Coalesced sparse COO tensor on ``edge_index.device``.
    """
    print(f"EIGEN_PATH: {eigen_path}")
    lam, U = diagonalize(edge_index, path=eigen_path, cache=eigen_cache)

    lower, upper = e0 - sigma, e0 + sigma
    # Product of two logistic steps: close to 1 inside (lower, upper), close to 0 outside.
    response = (torch.special.expit(steepness * (lam - lower))
                * torch.special.expit(steepness * (upper - lam)))

    # U @ diag(v) == U * v (column scaling), which avoids an O(n^3) dense matmul.
    Ltilde = (U * (lam * response)) @ U.T

    if sparsifier == 'cutoff':
        Ltilde = cutoff_sparsifier(Ltilde, cutoff)
    elif sparsifier == 'topk':
        Ltilde = topk_sparsifier(Ltilde, cutoff)

    deg_inv_sqrt = Ltilde.abs().sum(dim=0).pow_(-0.5)
    # An all-zero column gives 0 ** -0.5 = inf. Map it to 0 so that node
    # contributes nothing; this matches the sibling filters (the old -0.5
    # flipped the signs of that node's row).
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float("inf"), 0.0)
    # Equivalent to diag(d) @ Ltilde @ diag(d), computed by broadcasting.
    Ltilde = deg_inv_sqrt[:, None] * Ltilde * deg_inv_sqrt[None, :]
    return Ltilde.to_sparse().coalesce().to(edge_index.device)


def wks_filter(edge_index, e0, sigma, cutoff, eigen_path, eigen_cache, sparsifier='cutoff'):
    """Build a normalized, sparsified log-Gaussian spectral operator of the graph.

    The response ``h(lam) = exp(-(log(e0) - log(lam))**2 / (2 * sigma**2))`` is
    evaluated on the spectrum returned by ``diagonalize``. The operator
    ``U diag(lam * h(lam)) U^T`` is then formed, optionally sparsified, and
    scaled on both sides by the inverse square root of its absolute column sums.

    Args:
        edge_index: graph connectivity passed to ``diagonalize``. The result is
            moved to its device.
        e0: centre of the filter in the log domain. Expected to be > 0.
        sigma: width of the filter in the log domain. Must be non-zero.
        cutoff: magnitude threshold for ``sparsifier='cutoff'``, or the number
            of entries kept per column for ``sparsifier='topk'``.
        eigen_path: directory for the eigendecomposition cache (may be ``None``).
        eigen_cache: whether to read/write that cache.
        sparsifier: ``'cutoff'`` or ``'topk'``. Any other value skips
            sparsification.

    Returns:
        Coalesced sparse COO tensor on ``edge_index.device``.

    Raises:
        ValueError: if ``sigma`` is zero (the response would be NaN/inf).
    """
    if sigma == 0:
        raise ValueError("sigma must be non-zero")

    lam, U = diagonalize(edge_index, path=eigen_path, cache=eigen_cache)

    # NOTE(review): diagonalize() returns eigenvalues rescaled to [-1, 1], so
    # log(lam) is NaN for negative eigenvalues and -inf at 0. nan_to_num maps
    # NaN to 0 (as if lam == 1) and -inf to -100. Confirm this is intended.
    loglam = torch.nan_to_num(torch.log(lam), nan=0, neginf=-100, posinf=100)
    log_e0 = float(np.log(e0))
    arg = (log_e0 - loglam).pow_(2) / (2 * sigma ** 2)
    response = torch.exp(-arg)

    # U @ diag(v) == U * v (column scaling), which avoids an O(n^3) dense matmul.
    Ltilde = (U * (lam * response)) @ U.T

    if sparsifier == 'cutoff':
        Ltilde = cutoff_sparsifier(Ltilde, cutoff)
    elif sparsifier == 'topk':
        Ltilde = topk_sparsifier(Ltilde, cutoff)

    deg_inv_sqrt = Ltilde.abs().sum(dim=0).pow_(-0.5)
    # An all-zero column gives 0 ** -0.5 = inf. Map it to 0 so that node contributes nothing.
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float("inf"), 0.0)
    # Equivalent to diag(d) @ Ltilde @ diag(d), computed by broadcasting.
    Ltilde = deg_inv_sqrt[:, None] * Ltilde * deg_inv_sqrt[None, :]
    return Ltilde.to_sparse().coalesce().to(edge_index.device)


def gaussian_wks_filter(edge_index, e0, sigma, cutoff, eigen_path, eigen_cache, sparsifier='cutoff'):
    """Build a normalized, sparsified Gaussian spectral operator of the graph.

    The response ``h(lam) = exp(-(e0 - lam)**2 / (2 * sigma**2))`` is evaluated
    on the spectrum returned by ``diagonalize``. The operator
    ``U diag(lam * h(lam)) U^T`` is then formed, optionally sparsified, and
    scaled on both sides by the inverse square root of its absolute column sums.

    Args:
        edge_index: graph connectivity passed to ``diagonalize``. The result is
            moved to its device.
        e0: centre of the Gaussian in eigenvalue space.
        sigma: width of the Gaussian. Must be non-zero.
        cutoff: magnitude threshold for ``sparsifier='cutoff'``, or the number
            of entries kept per column for ``sparsifier='topk'``.
        eigen_path: directory for the eigendecomposition cache (may be ``None``).
        eigen_cache: whether to read/write that cache.
        sparsifier: ``'cutoff'`` or ``'topk'``. Any other value skips
            sparsification.

    Returns:
        Coalesced sparse COO tensor on ``edge_index.device``.

    Raises:
        ValueError: if ``sigma`` is zero (the response would be NaN/inf).
    """
    if sigma == 0:
        raise ValueError("sigma must be non-zero")

    lam, U = diagonalize(edge_index, path=eigen_path, cache=eigen_cache)

    arg = (e0 - lam).pow_(2) / (2 * sigma ** 2)
    response = torch.exp(-arg)

    # U @ diag(v) == U * v (column scaling), which avoids an O(n^3) dense matmul.
    Ltilde = (U * (lam * response)) @ U.T

    if sparsifier == 'cutoff':
        Ltilde = cutoff_sparsifier(Ltilde, cutoff)
    elif sparsifier == 'topk':
        Ltilde = topk_sparsifier(Ltilde, cutoff)

    deg_inv_sqrt = Ltilde.abs().sum(dim=0).pow_(-0.5)
    # An all-zero column gives 0 ** -0.5 = inf. Map it to 0 so that node contributes nothing.
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float("inf"), 0.0)
    # Equivalent to diag(d) @ Ltilde @ diag(d), computed by broadcasting.
    Ltilde = deg_inv_sqrt[:, None] * Ltilde * deg_inv_sqrt[None, :]
    return Ltilde.to_sparse().coalesce().to(edge_index.device)