from math import log

from typing import List, Optional
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, SparseTensor
from torch import Tensor

import torch
from torch.nn import Dropout, Linear, Module, Parameter, ReLU, Tanh
from torch_geometric.nn import global_add_pool, global_max_pool, global_mean_pool, GCNConv, GCN2Conv, GATConv
from torch_geometric.utils import to_dense_adj

from .kernel_vectors import kernel_vectors, ker_lapl

class ColumnDropout(Dropout):
    """Dropout that zeroes whole slices along dimension 0.

    Each index ``i`` of ``input``'s first dimension is kept with probability
    ``1 - p`` and zeroed otherwise. Kept slices are *not* rescaled by
    ``1 / (1 - p)`` (unlike :class:`torch.nn.Dropout`).

    NOTE(review): dropping is controlled only by the ``training`` argument,
    which defaults to ``True``. ``self.training`` (``module.eval()``) is not
    consulted, so callers must pass ``training=False`` explicitly at
    inference time. Confirm this is intended.
    """

    def forward(self, input: Tensor, training: bool = True) -> Tensor:
        """Return ``input`` with random dim-0 slices zeroed when ``training``.

        Args:
            input: tensor of rank >= 1. The per-index mask is broadcast over
                all trailing dimensions.
            training: when false, ``input`` is returned unchanged.

        Returns:
            A tensor with the same shape, dtype and device as ``input``.
        """
        if not training:
            return input
        # Draw the mask on the input's device (no host->device copy). Cast it
        # to the input's dtype so half/bfloat16 inputs are not promoted to
        # float32 by the multiplication.
        keep = torch.rand(input.size(0), device=input.device) > self.p
        # Shape (N, 1, ..., 1) so the mask broadcasts over every trailing
        # dim, whatever the input's rank.
        mask_shape = (-1,) + (1,) * (input.dim() - 1)
        return input * keep.to(input.dtype).view(mask_shape)

class SCT_Resid(Module):
    """Residual layer operating in the span of a set of kernel vectors.

    ``forward`` does the following:

    1. Projects the current features ``x`` and the initial features ``x0``
       with ``ker_vecs``.
    2. Mixes the two projections through a shared linear map ``B`` with
       weight ``alpha``.
    3. Multiplies the mix by a softmax (over ``dim=1``) of the projected
       ``x``.
    4. Maps the result back with the transpose of ``ker_vecs``.

    NOTE(review): ``in_channels``, ``theta`` and ``dropout`` are stored or
    accepted but not used by ``forward``. ``B`` is
    ``Linear(out_channels, out_channels)``, so the feature dimension of
    ``x``/``x0`` must equal ``out_channels``. Confirm with callers.
    """

    _cached_edge_index: Optional[OptPairTensor]
    _cached_adj_t: Optional[SparseTensor]
    _cached_ker: Optional[SparseTensor]

    def __init__(self, in_channels: int, out_channels: int,
                 alpha: float, theta: float, layer: int,
                 ker_vecs: OptTensor = None, dropout: float = 0.0,
                 indicators: OptTensor = None,
                 cached: bool = False,
                 ):
        """Build the layer.

        Args:
            in_channels: stored as ``self.in_channels``; not used otherwise.
            out_channels: feature width of the shared linear map ``B``.
            alpha: weight of the ``x0`` branch; the ``x`` branch gets
                ``1 - alpha``.
            theta: combined with ``layer`` into
                ``self.theta = log(theta / layer + 1)``.
            layer: layer index used in the ``theta`` formula (must be
                non-zero).
            ker_vecs: optional precomputed kernel matrix. When given, it is
                used on every call regardless of ``cached``.
            dropout: accepted for interface compatibility; currently unused.
            indicators: optional indicator values, stored as a ``long``
                tensor in ``self.indicators``; not used by ``forward``.
            cached: when true, a kernel computed lazily in ``forward`` is
                kept for later calls.
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cached = cached

        self._cached_ker = ker_vecs
        # as_tensor avoids the copy-construct warning torch.tensor emits when
        # given an existing tensor.
        self.indicators = (None if indicators is None
                           else torch.as_tensor(indicators, dtype=torch.long))

        self.B = Linear(out_channels, out_channels)
        self.alpha = alpha
        self.theta = log(theta / layer + 1)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the shared linear map ``B``."""
        self.B.reset_parameters()

    def forward(self, x: Tensor, edge_index: Adj,
                x0: OptTensor = None,
                edge_weight: OptTensor = None) -> Tensor:
        """Apply the kernel-space residual update.

        Args:
            x: current node features; last dim must be ``out_channels``.
            edge_index: graph connectivity, used only when no kernel is
                cached.
            x0: initial node features, same shape as ``x``. Required despite
                the ``None`` default.
            edge_weight: optional edge weights for computing the kernel.

        Returns:
            ``ker_vecs.T @ (mix * softmax(ker_vecs @ x, dim=1))``, where
            ``mix = (1 - alpha) * B(ker_vecs @ x) + alpha * B(ker_vecs @ x0)``.
            Assumes ``ker_vecs`` is (k, N) and ``x`` is (N, C), giving an
            (N, C) result. TODO confirm.

        Raises:
            ValueError: if ``x0`` is ``None``.
        """
        if x0 is None:
            raise ValueError("SCT_Resid.forward requires the initial features x0")

        with torch.no_grad():
            if self._cached_ker is None:
                ker_vecs = ker_lapl(edge_index, edge_weight)
                if self.cached:
                    self._cached_ker = ker_vecs
            else:
                # Move the stored kernel once and keep the moved copy, so
                # later calls on the same device skip the transfer.
                ker_vecs = self._cached_ker.to(x.device)
                self._cached_ker = ker_vecs

        kx = ker_vecs @ x
        mixed = (1 - self.alpha) * self.B(kx) + self.alpha * self.B(ker_vecs @ x0)
        weights = mixed * torch.softmax(kx, dim=1)
        # sum_b weights[b, c] * ker_vecs[b, n], i.e. ker_vecs.T @ weights,
        # without materialising a (k, C, N) intermediate.
        return ker_vecs.t() @ weights