from typing import Callable, Optional, Union

import torch
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.typing import Adj, OptTensor, PairOptTensor, PairTensor

try:
    from torch_cluster import knn
except ImportError:
    knn = None


class RelativeEdgeConv(MessagePassing):
    """Edge convolution whose message is ``nn(x_j - x_i)``.

    For every edge, the difference between the neighbor's features
    (``x_j``) and the central node's features (``x_i``) is passed through
    ``nn``. The resulting messages are then combined per node with ``aggr``.

    Args:
        nn: Callable applied to the per-edge feature differences. It must
            accept a tensor with the same trailing feature size as ``x``.
        aggr: Aggregation scheme forwarded to ``MessagePassing``
            (default ``'sum'``).
        **kwargs: Extra keyword arguments forwarded to ``MessagePassing``.
    """

    def __init__(self, nn: Callable, aggr: str = 'sum', **kwargs):
        super().__init__(aggr=aggr, **kwargs)
        self.nn = nn
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parent layer's state and the parameters of ``self.nn``."""
        # A bare ``super()`` expression is a no-op; the parent reset has to be
        # called explicitly. ``MessagePassing.reset_parameters`` is not
        # present in every PyG release, so look it up defensively.
        parent_reset = getattr(super(), 'reset_parameters', None)
        if parent_reset is not None:
            parent_reset()
        reset(self.nn)

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor:
        """Run message passing over ``edge_index``.

        Args:
            x: Node features. This is either a single tensor used for both
                ends of every edge, or a ``(source, target)`` pair of tensors.
            edge_index: Graph connectivity in any form PyG's ``Adj`` accepts.

        Returns:
            The aggregated messages produced by ``propagate``.
        """
        if isinstance(x, Tensor):
            # Same features on both sides of each edge.
            x: PairTensor = (x, x)
        # propagate_type: (x: PairTensor)
        return self.propagate(edge_index, x=x, size=None)

    def message(self, x_i: Tensor, x_j: Tensor) -> Tensor:
        """Compute one message per edge as ``nn(x_j - x_i)``.

        ``x_i`` and ``x_j`` are the per-edge features that ``propagate``
        gathers for the central (``_i``) and neighboring (``_j``) node.
        """
        return self.nn(x_j - x_i)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(nn={self.nn})'
    
    

class DistEdgeConv(MessagePassing):
    """Edge convolution whose message depends only on node-pair distance.

    For every edge, the quantity ``sqrt(||x_j - x_i||_2)`` is computed and
    reshaped to a single-column tensor ``[num_edges, 1]``. That tensor is
    passed through ``nn``, and the resulting messages are combined per node
    with ``aggr``.

    Args:
        nn: Callable applied to the ``[num_edges, 1]`` distance feature, so
            its input size must be 1.
        aggr: Aggregation scheme forwarded to ``MessagePassing``
            (default ``'mean'``).
        **kwargs: Extra keyword arguments forwarded to ``MessagePassing``.
    """

    def __init__(self, nn: Callable, aggr: str = 'mean', **kwargs):
        super().__init__(aggr=aggr, **kwargs)
        self.nn = nn
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parent layer's state and the parameters of ``self.nn``."""
        # A bare ``super()`` expression is a no-op; the parent reset has to be
        # called explicitly. ``MessagePassing.reset_parameters`` is not
        # present in every PyG release, so look it up defensively.
        parent_reset = getattr(super(), 'reset_parameters', None)
        if parent_reset is not None:
            parent_reset()
        reset(self.nn)

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor:
        """Run message passing over ``edge_index``.

        Args:
            x: Node features. This is either a single tensor used for both
                ends of every edge, or a ``(source, target)`` pair of tensors.
            edge_index: Graph connectivity in any form PyG's ``Adj`` accepts.

        Returns:
            The aggregated messages produced by ``propagate``.
        """
        if isinstance(x, Tensor):
            # Same features on both sides of each edge.
            x: PairTensor = (x, x)
        # propagate_type: (x: PairTensor)
        return self.propagate(edge_index, x=x, size=None)

    def message(self, x_i: Tensor, x_j: Tensor) -> Tensor:
        """Compute one message per edge from the distance between endpoints.

        ``x_i`` and ``x_j`` are the per-edge features that ``propagate``
        gathers for the central (``_i``) and neighboring (``_j``) node. The
        norm is taken over ``dim=1``; this presumably expects 2-D features
        ``[num_edges, F]`` (TODO confirm).
        """
        # NOTE(review): this is the square root of the Euclidean distance
        # (d ** 0.5), not d or d ** 2. Confirm that this scaling is intended.
        dist = (x_j - x_i).norm(2, dim=1).sqrt()
        # Shape the result as a single feature column for ``nn``.
        return self.nn(dist.view(-1, 1))

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(nn={self.nn})'