from typing import Union, Tuple, Callable
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Parameter
from torch_scatter import scatter_mean, scatter_sum

class ContinuousConv(nn.Module):
    """Edge-conditioned graph convolution over batched node features.

    For every edge ``(src, dst)`` in ``edge_idx`` the callable ``nn`` maps the
    edge attributes to an ``in_channels x out_channels`` matrix, which is
    applied to the source node's features. The target node's own features,
    scaled by ``1 + 1e-2``, are added to each message. Messages are then
    aggregated onto their target nodes with a mean or a sum, and an optional
    bias is added.

    Args:
        in_channels: Size of each input node feature vector. The annotation
            admits a tuple, but ``forward`` compares it against an integer
            feature size and uses it in a reshape, so only an int works here.
        out_channels: Size of each output feature vector. Because the target
            node's features are added to every message, this must equal
            ``in_channels`` for the addition to broadcast.
        nn: Callable applied to ``edge_attr``. Its output must be reshapeable
            to ``(batch, num_edges, in_channels, out_channels)``.
        aggr: Aggregation over incoming messages, ``'mean'`` or ``'sum'``.
        bias: If True, learn an additive bias of size ``out_channels``,
            initialised to zero.
        **kwargs: Accepted for call-site compatibility and ignored.

    Raises:
        ValueError: If ``aggr`` is not a supported aggregation.
    """

    _AGGREGATIONS = ('mean', 'sum')

    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, nn: Callable, aggr: str = 'mean',
                 bias: bool = True, **kwargs):
        super().__init__()
        # Previously an unknown aggr silently produced an all-zero aggregation.
        if aggr not in self._AGGREGATIONS:
            raise ValueError(
                f"Unsupported aggregation {aggr!r}; "
                f"expected one of {self._AGGREGATIONS}")
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nn = nn
        self.aggr = aggr
        if bias:
            # torch.Tensor(n) returns uninitialised memory; start from zeros.
            self.bias = Parameter(torch.zeros(out_channels))
        else:
            self.register_parameter('bias', None)

    def forward(self, x: Tensor, edge_idx: Tensor, edge_attr: Tensor) -> Tensor:
        """Compute edge-conditioned messages and aggregate them per target node.

        Args:
            x: Node features of shape ``(batch, num_nodes, in_channels)``.
            edge_idx: Edge indices. Row 0 holds source nodes and row 1 holds
                target nodes.
            edge_attr: Input to ``self.nn``. Presumably per-edge attributes
                aligned with ``edge_idx``; verify against callers.

        Returns:
            Tensor of shape ``(batch, num_nodes, out_channels)``.
        """
        src, dst = edge_idx[0], edge_idx[1]

        x_src = x[:, src, :]
        batch_size, num_edges, embed_size = x_src.shape
        assert embed_size == self.in_channels, (
            f"expected {self.in_channels} input features, got {embed_size}")
        x_dst = x[:, dst, :]

        # One (in_channels x out_channels) matrix per (batch, edge).
        weight = self.nn(edge_attr).reshape(
            batch_size, num_edges, self.in_channels, self.out_channels)
        message = torch.einsum('bni,bnij->bnj', x_src, weight)
        # Self term from the target node. NOTE(review): the rationale for the
        # fixed 1.01 factor is not visible here; this requires
        # in_channels == out_channels.
        message = message + x_dst * (1 + 1e-2)

        # Output buffer: one row per node of x (nodes without incoming edges
        # start at zero), out_channels features each.
        out = x.new_zeros(batch_size, x.size(1), self.out_channels)
        if self.aggr == 'mean':
            out = scatter_mean(message, dst, dim=-2, out=out)
        elif self.aggr == 'sum':
            out = scatter_sum(message, dst, dim=-2, out=out)
        else:
            raise ValueError(
                f"Unsupported aggregation {self.aggr!r}; "
                f"expected one of {self._AGGREGATIONS}")

        if self.bias is not None:
            out = out + self.bias
        return out


