from typing import List
from torch import Tensor
from torch_geometric.typing import Adj, OptTensor, SparseTensor

import torch
from torch_geometric.utils import degree, to_dense_adj
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_scatter import scatter

from .connected_components import connected_components

# jb: compute the eigenvalue-1 eigenspace of the (augmented) adjacency matrix
def ker_lapl(edge_index, edge_weight, atol: float = 1e-3, rtol: float = 1e-3):
    """Return the eigenvectors of the dense adjacency whose eigenvalue is ~1.

    The dense matrix ``A`` is built from ``edge_index`` / ``edge_weight`` via
    ``to_dense_adj`` (first graph of the batch). The returned rows span the
    eigenvalue-1 eigenspace of ``A``, i.e. the kernel of ``I - A``.
    NOTE(review): presumably ``A`` is the normalized augmented adjacency, so
    this is the kernel of the normalized Laplacian; confirm with callers.

    Args:
        edge_index: Edge indices passed to ``to_dense_adj``.
        edge_weight: Optional edge weights, used as the dense matrix entries.
        atol: Absolute tolerance for treating an eigenvalue as 1.
        rtol: Relative tolerance for treating an eigenvalue as 1.

    Returns:
        A real tensor of shape ``(k, n)``. Each row is a unit-norm eigenvector,
        and the rows are mutually orthonormal.

    Raises:
        ValueError: If the dense adjacency is not symmetric.
    """
    aug = to_dense_adj(edge_index, edge_attr=edge_weight)[0]
    if not torch.equal(aug, aug.T):
        raise ValueError('Augmented adjacency is not symmetric')

    # The matrix is symmetric, so the symmetric solver applies. It returns
    # real eigenpairs and orthonormal eigenvectors even for repeated
    # eigenvalues (e.g. several connected components), which the general
    # `eig` does not guarantee.
    eigs, vecs = torch.linalg.eigh(aug)
    # Compare against a reference with the same dtype/device as `eigs`;
    # `isclose` rejects mixed dtypes (e.g. float64 weights vs float32 literal).
    perp = torch.isclose(eigs, torch.ones_like(eigs), atol=atol, rtol=rtol)
    # Eigenvectors are the columns of `vecs`; transpose so they become rows.
    return vecs.T[perp]


def _node_degree(edge_index, flow, edge_weight):
    """Return ``(deg, num_nodes)``: the weighted degree of every node and the node count."""
    if isinstance(edge_index, SparseTensor):
        if not edge_index.has_value():
            edge_index = edge_index.fill_value(1., dtype=torch.float)
        # NOTE(review): row-sums the sparse matrix, like PyG's gcn_norm does for
        # a transposed adjacency `adj_t`; `flow`/`edge_weight` do not apply here.
        return edge_index.sum(dim=1), edge_index.size(0)

    num_nodes = maybe_num_nodes(edge_index)
    row, col = edge_index[0], edge_index[1]
    # Aggregate at the receiving node for the requested message-passing flow.
    idx = col if flow == 'source_to_target' else row
    if edge_weight is None:
        edge_weight = torch.ones(idx.numel(), dtype=torch.float, device=idx.device)
    elif not edge_weight.is_floating_point():
        # Later fractional powers need a floating dtype.
        edge_weight = edge_weight.to(torch.float)
    deg = edge_weight.new_zeros(num_nodes).index_add_(0, idx, edge_weight)
    return deg, num_nodes


def kernel_vectors(
    edge_index : Adj,
    flow : str ='source_to_target',
    edge_weight : OptTensor = None,
    indicators : List[Tensor] = None,
    largest : bool = False,
    return_all : bool = False,
    single: bool = False,
)-> List[Tensor]:
    """Build one unit-norm vector ``D^{1/2} 1_c`` per connected component ``c``.

    Each vector is the square root of the node degrees, restricted to the nodes
    of one component and then L2-normalized.

    Args:
        edge_index: Edge indices tensor or a ``SparseTensor`` adjacency.
        flow: ``'source_to_target'`` sums edge weights at the target node
            (``edge_index[1]``); any other value sums them at the source node.
        edge_weight: Optional per-edge weights (ignored for ``SparseTensor``,
            which carries its own values). Defaults to all ones.
        indicators: Optional per-node component labels ``0..C-1``. When ``None``
            or empty, they are computed with ``connected_components``.
        largest: Forwarded to ``connected_components``.
        return_all: If True, also return the labels and ``D^{-1/2}``.
        single: If True, treat the whole graph as a single component.

    Returns:
        ``ker_vecs`` of shape ``(C, num_nodes)``. If ``return_all`` is set, the
        tuple ``(indicators, deg_inv_sqrt, ker_vecs)`` is returned instead, where
        ``deg_inv_sqrt`` has zero-degree entries set to 0.
    """
    deg, num_nodes = _node_degree(edge_index, flow, edge_weight)
    # Out-of-place powers: `deg` is reused below for `deg_inv_sqrt`.
    deg_sqrt = deg.pow(0.5)

    if single:
        labels = torch.zeros(num_nodes, dtype=torch.long, device=deg.device)
        if indicators is None:
            indicators = [0] * num_nodes
    else:
        if indicators is None or len(indicators) == 0:
            indicators = connected_components(edge_index, num_nodes, edge_weight=edge_weight, largest=largest)
        labels = torch.as_tensor(indicators, dtype=torch.long, device=deg.device)

    num_components = int(labels.max()) + 1 if labels.numel() > 0 else 0
    component_ids = torch.arange(num_components, device=deg.device)
    # (C, N) membership mask: row c selects the nodes labelled c.
    membership = (labels.unsqueeze(0) == component_ids.unsqueeze(1)).to(deg_sqrt.dtype)
    ker_vecs = deg_sqrt.unsqueeze(0) * membership
    # NOTE: a component whose total degree is 0 yields a NaN row (0 / 0).
    ker_vecs = ker_vecs / ker_vecs.norm(dim=1, keepdim=True)

    if not return_all:
        return ker_vecs

    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
    return indicators, deg_inv_sqrt, ker_vecs
