from typing import List, Optional
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, SparseTensor
from torch import Tensor

import torch
from torch.nn import Linear, Module, Parameter, ReLU, Tanh
from torch_geometric.nn import global_add_pool, global_max_pool, global_mean_pool, GCNConv, GCN2Conv
from torch_geometric.utils import to_dense_adj

from .kernel_vectors import kernel_vectors, ker_lapl

class SCT(Module):
    """Learned re-weighting of the projection of node features onto kernel vectors.

    For kernel vectors ``K`` (one vector per row) and node features ``x``,
    ``forward`` returns ``K.T @ (alpha * (K @ x))``, where ``alpha`` is the
    transposed weight of the internal ``Linear`` layer ``self.alpha``. In other
    words, each kernel vector is scaled by its learned coefficient times its
    inner product with ``x``, and the scaled vectors are summed.

    Kernel vectors are either supplied up front via ``ker_vecs`` or computed
    from the graph with ``ker_lapl`` inside ``forward``.

    Args:
        in_channels: Rows of the learned coefficient matrix ``alpha``.
            NOTE(review): the broadcast ``alpha * (K @ x)`` appears to require
            this to equal the number of kernel vectors (or be 1) — confirm.
        out_channels: Columns of ``alpha``; must broadcast against the feature
            dimension of ``x``.
        ker_vecs: Optional precomputed kernel vectors. When given, they are
            used on every ``forward`` call and ``ker_lapl`` is never called.
        indicators: Optional indicator values, stored (as a ``torch.long``
            copy) in ``self._cached_ind``. Not read anywhere in this class.
        cached: If True, kernel vectors computed on the first ``forward`` call
            are stored and reused on later calls.
        single: Stored as ``self.single``. NOTE(review): it no longer affects
            ``forward``. It was only passed to a ``kernel_vectors`` call whose
            results were all discarded.
    """

    # Unused by this class; kept so the declared attributes are unchanged.
    _cached_edge_index: Optional[OptPairTensor]
    _cached_adj_t: Optional[SparseTensor]
    # Dense kernel vectors (one per row); used with `@` and `.T` in forward.
    _cached_ker: Optional[Tensor]
    _cached_ind: Optional[Tensor]

    def __init__(self, in_channels: int, out_channels: int,
                 ker_vecs : OptTensor = None,
                 indicators : OptTensor = None,
                 cached : bool = False,
                 single: bool = False,
                 ):

        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cached = cached
        self.single = single

        # Precomputed kernel vectors, if given, take priority in forward().
        self._cached_ker = ker_vecs
        if indicators is not None:
            if isinstance(indicators, Tensor):
                # torch.tensor(<Tensor>) warns; make the same detached copy explicitly.
                indicators = indicators.detach().to(dtype=torch.long, copy=True)
            else:
                indicators = torch.tensor(indicators, dtype=torch.long)
        self._cached_ind = indicators

        # NOTE(review): forward() only reads alpha.weight; alpha.bias exists but
        # is never used. It is kept so existing state_dicts still load.
        self.alpha = Linear(in_channels, out_channels)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the learned coefficients held in ``self.alpha``."""
        self.alpha.reset_parameters()


    def forward(self, x: Tensor, edge_index: Adj,
                x0: OptTensor = None,
                edge_weight: OptTensor = None) -> Tensor:
        """Project ``x`` onto the kernel vectors and recombine with learned weights.

        Args:
            x: Node features. ``ker_vecs @ x`` must be defined, so presumably
                ``(num_nodes, num_features)`` — TODO confirm.
            edge_index: Graph connectivity, passed to ``ker_lapl`` when no
                kernel vectors are cached.
            x0: Unused; accepted for signature compatibility.
            edge_weight: Optional edge weights, passed to ``ker_lapl``.

        Returns:
            ``ker_vecs.T @ (alpha * (ker_vecs @ x))``, where ``alpha`` is
            ``self.alpha.weight.T``.
        """
        with torch.no_grad():
            if self._cached_ker is None:
                # Move to x's device *before* use and before caching, so the
                # first call and non-cached calls also run on the right device.
                ker_vecs = ker_lapl(edge_index, edge_weight).to(x.device)
                if self.cached:
                    self._cached_ker = ker_vecs
            else:
                self._cached_ker = self._cached_ker.to(x.device)
                ker_vecs = self._cached_ker

        alpha = self.alpha.weight.T
        coeffs = alpha * (ker_vecs @ x)
        # Same value as einsum('bi,bj->bij', coeffs, ker_vecs).sum(0).T, but
        # without materialising the (num_ker, num_features, num_nodes) tensor.
        return ker_vecs.T @ coeffs
