from typing import List, Optional
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, SparseTensor
from torch import Tensor
from torch.nn import Module

import torch

from .kernel_vectors import kernel_vectors

class NodeFeatureSmoothness(Module):
    r"""Score how much of each feature column of ``x`` lies along each kernel vector.

    For a feature column :math:`z` (one column of ``x``) and a kernel vector
    :math:`v`, the score is the relative norm of the orthogonal projection of
    :math:`z` onto :math:`\operatorname{span}(v)`:

    .. math::
        s(z, v) = \frac{\lVert \tfrac{\langle z, v \rangle}{\langle v, v \rangle} v \rVert_2}
                       {\lVert z \rVert_2}

    A score of 1 means :math:`z` is parallel to :math:`v`; 0 means it is
    orthogonal to it. An all-zero column scores 1 against every vector.

    Kernel vectors are either supplied at construction or obtained from
    ``kernel_vectors(edge_index, edge_weight=...)`` during ``forward``.
    NOTE(review): presumably these span the null space of some graph operator;
    confirm against ``kernel_vectors``.

    Args:
        ker_vecs: iterable of 1-D kernel vectors, each of length ``x.size(0)``.
            When given, these are always used and ``kernel_vectors`` is never
            called.
        cached: when True and ``ker_vecs`` is not given, the vectors computed
            on the first ``forward`` call are stored and reused on later calls
            (the graph arguments of later calls are then ignored).
        eps: feature columns whose 2-norm is at most ``eps`` are treated as
            zero and scored 1.0.
    """

    # Kernel vectors in use; None until supplied or computed with caching on.
    _cached_ker: OptTensor
    # ``v / <v, v>`` for each vector in ``_cached_ker`` (empty until set).
    _cached_proj: List[Tensor]

    def __init__(self,
                 ker_vecs: OptTensor = None,
                 cached: bool = False,
                 eps: float = 1e-16,
                 ):

        super().__init__()

        self.cached = cached
        self.eps = eps

        self._cached_ker = ker_vecs
        self._cached_proj = ([] if ker_vecs is None
                             else self._projection_vectors(ker_vecs))

    @staticmethod
    def _projection_vectors(ker_vecs) -> List[Tensor]:
        """Return ``v / <v, v>`` for every kernel vector ``v``.

        ``<z, v> * (v / <v, v>)`` is the orthogonal projection of ``z`` onto
        span(v). A zero vector spans only {0}, so it maps to a zero vector
        instead of dividing by zero.
        """
        projs = []
        for kv in ker_vecs:
            sq_norm = torch.inner(kv, kv)
            projs.append(kv / sq_norm if sq_norm > 0 else torch.zeros_like(kv))
        return projs

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> List[List[float]]:
        """Compute the projection score of every feature column on every kernel vector.

        Args:
            x: node-feature matrix; each column (``x.T[j]``) is scored.
            edge_index: graph connectivity, passed to ``kernel_vectors`` only
                when no kernel vectors are stored.
            edge_weight: optional edge weights, forwarded the same way.

        Returns:
            One list per feature column of ``x``; each inner list holds one
            float score per kernel vector, in kernel-vector order.
        """
        if self._cached_ker is None:
            ker_vecs = kernel_vectors(edge_index, edge_weight=edge_weight)
            proj = self._projection_vectors(ker_vecs)
            if self.cached:
                # Reuse the vectors (and their projections) on later calls.
                self._cached_ker = ker_vecs
                self._cached_proj = proj
        else:
            ker_vecs = self._cached_ker
            proj = self._cached_proj

        scores: List[List[float]] = []
        for col in x.T:
            col_norm = torch.norm(col, 2)
            if col_norm <= self.eps:
                # A zero column would divide by zero; treat it as fully
                # aligned with every kernel vector.
                scores.append([1.0] * len(proj))
                continue
            col_scores = []
            for kv, p in zip(ker_vecs, proj):
                zi_m = torch.inner(col, kv) * p  # projection of col onto span(kv)
                col_scores.append((torch.norm(zi_m, 2) / col_norm).item())
            scores.append(col_scores)
        return scores