
import numpy as np
from typing import Optional, Tuple

import torch
from torch import Tensor

from torch_geometric.typing import OptTensor
from torch_geometric.utils import add_self_loops, remove_self_loops, scatter
from torch_geometric.utils.num_nodes import maybe_num_nodes


def calc_norm(
    edge_index: Tensor,
    edge_weight: Optional[Tensor] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    r"""Return the spectral norm of the matrix encoded by a COO edge list.

    The entries are densified into a matrix ``A`` whose shape is inferred from
    the largest index in ``edge_index`` (duplicate entries are summed by the
    densification), and ``torch.linalg.matrix_norm(A, 2)`` -- the largest
    singular value of ``A`` -- is returned.

    Args:
        edge_index: Index tensor of shape ``(2, E)``.
        edge_weight: Entry values of shape ``(E,)``; all ones when ``None``.
        dtype: dtype of the default all-ones weights; ignored when
            ``edge_weight`` is given.

    Returns:
        A 0-dim tensor holding the spectral norm. An empty edge list yields
        zero.
    """
    if edge_weight is None:
        edge_weight = torch.ones(edge_index.size(1), dtype=dtype,
                                 device=edge_index.device)

    # The spectral norm of an empty (all-zero) matrix is 0; short-circuit
    # instead of densifying an edge list that has no entries at all.
    if edge_index.numel() == 0:
        return torch.zeros((), dtype=edge_weight.dtype,
                           device=edge_weight.device)

    A = torch.sparse_coo_tensor(edge_index, edge_weight).to_dense()
    return torch.linalg.matrix_norm(A, 2)
    






def get_laplacian(
    edge_index: Tensor,
    edge_weight: OptTensor = None,
    normalization: Optional[bool] = False,
    dtype: Optional[torch.dtype] = None,
    num_nodes: Optional[int] = None,
) -> Tuple[Tensor, OptTensor]:
    r"""Compute the combinatorial graph Laplacian :math:`L = D - A` in COO form.

    Self-loops in the input are dropped first. The returned edge list holds
    the off-diagonal entries ``-A`` followed by one diagonal entry
    ``D[i, i]`` for every node ``i`` in ``0 .. num_nodes - 1``, where ``D``
    is the weighted degree aggregated over the source (row) index.

    Args:
        edge_index: Graph connectivity of shape ``(2, E)``.
        edge_weight: Optional per-edge weights of shape ``(E,)``; all ones
            when ``None``.
        normalization: When truthy, divide every entry of :math:`L` by its
            spectral norm so that :math:`\|L\|_2 = 1`. A Laplacian whose norm
            is zero (no edges) is returned unscaled rather than as NaNs.
        dtype: dtype of the default all-ones weights; ignored when
            ``edge_weight`` is given.
        num_nodes: Number of nodes. When ``None`` it is inferred from
            ``edge_index`` *before* self-loops are removed, so a node that only
            carries a self-loop still gets a diagonal entry.

    Returns:
        ``(edge_index, edge_weight)`` describing :math:`L`.
    """
    # Infer the node count before stripping self-loops; otherwise a trailing
    # node whose only edge is a self-loop would vanish from the Laplacian.
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
    if edge_weight is None:
        edge_weight = torch.ones(edge_index.size(1), dtype=dtype,
                                 device=edge_index.device)

    row = edge_index[0]
    deg = scatter(edge_weight, row, 0, dim_size=num_nodes, reduce='sum')

    # L = D - A. The self-loops (i, i), i = 0..N-1, are appended after the
    # existing edges, so the diagonal values `deg` are concatenated last.
    edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
    edge_weight = torch.cat([-edge_weight, deg], dim=0)

    if normalization:
        # Scale by the spectral norm of L itself (largest singular value).
        norm = calc_norm(edge_index, edge_weight)
        if norm > 0:
            edge_weight = edge_weight / norm

    return edge_index, edge_weight
    





def get_resolvent(
    edge_index: Tensor,
    edge_weight: OptTensor = None,
    omega: Optional[float] = None,
    normalization: Optional[bool] = False,
    dtype: Optional[torch.dtype] = None,
    num_nodes: Optional[int] = None,
    normalizing_factor: Optional[float] = None,
) -> Tuple[Tensor, OptTensor]:
    r"""Compute the Laplacian resolvent :math:`R = (L - \omega I)^{-1}`.

    The Laplacian is built by ``get_laplacian`` and densified. It is
    optionally divided by ``normalizing_factor`` and then shifted by
    :math:`-\omega I` and inverted densely. The inverse is converted back to
    COO form; exact zeros are not stored.

    Args:
        edge_index: Graph connectivity of shape ``(2, E)``.
        edge_weight: Optional per-edge weights of shape ``(E,)``.
        omega: Required, strictly negative shift. It carries a ``None``
            default only because it follows a defaulted parameter.
        normalization: Forwarded to ``get_laplacian``.
        dtype: Forwarded to ``get_laplacian`` (dtype of default weights).
        num_nodes: Number of nodes; inferred from ``edge_index`` when ``None``.
        normalizing_factor: If given, :math:`L` is divided by it before the
            shift. This is applied in addition to ``normalization``.

    Returns:
        ``(edge_index, edge_attr)`` of the resolvent matrix.

    Raises:
        ValueError: If ``omega`` is missing or not strictly negative.
    """
    # `not omega < 0` (rather than `omega >= 0`) also rejects NaN.
    if omega is None or not omega < 0:
        raise ValueError(f"omega must be a negative number, got {omega!r}")

    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    L_index, L_weight = get_laplacian(edge_index, edge_weight, normalization,
                                      dtype=dtype, num_nodes=num_nodes)
    L = torch.sparse_coo_tensor(L_index, L_weight,
                                size=(num_nodes, num_nodes)).to_dense()
    if normalizing_factor is not None:
        L = L / normalizing_factor

    # Build the identity on L's device and dtype so CUDA / float64 inputs work.
    identity = torch.eye(L.size(0), dtype=L.dtype, device=L.device)
    T = L - omega * identity
    R = torch.linalg.inv(T).to_sparse()

    return R.indices(), R.values()
    





