from typing import Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import (
    Adj,
    OptTensor,
    PairTensor,
)
from torch_geometric.utils import (
    softmax,
)


class MC_GIN(MessagePassing):
    r"""GIN-style message passing with learned, per-head edge weights.

    For every edge ``(j -> i)`` the layer computes one weight per head,
    ``alpha_ij = mlp(cat([x_i, x_j], dim=-1))``. Precomputed weights can be
    passed to :meth:`forward` instead. The weights are optionally normalized
    with :func:`torch_geometric.utils.softmax`. They then scale the projected
    source features ``lin(x_j)``, viewed as ``[num_edges, heads, out_channels]``.
    The scaled messages are aggregated per node by :class:`MessagePassing`.
    Finally the layer sums over the head dimension and adds a bias.

    Args:
        in_channels: Size of the input node features. For bipartite graphs a
            ``(src, dst)`` tuple may be given; :meth:`forward` then expects
            ``x`` as an ``(x_src, x_dst)`` pair.
        out_channels: Size of the output node features.
        heads: Number of weight heads. The default MLP emits one weight per
            head, and ``lin`` produces one ``out_channels`` slot per head.
        use_softmax: If ``True``, the edge weights are softmax-normalized,
            grouped by the ``index`` tensor PyG supplies to :meth:`message`.
            For the default flow this is presumably the target node (confirm).
        nn: Optional replacement for the default edge-weight MLP. It is
            called on ``cat([x_i, x_j], dim=-1)``. Its output must broadcast
            against ``[num_edges, heads]`` once a trailing dim is added.
        **kwargs: Forwarded to :class:`MessagePassing` (e.g. ``aggr``).
    """

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        use_softmax: bool = False,
        nn = None,
        **kwargs,
    ):
        # node_dim=0 because projected features are laid out as
        # [num_nodes, heads, out_channels] (see the `.view` in forward).
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.use_softmax = use_softmax

        # Source / target feature widths. A plain int means both sides of
        # every edge share the same feature size.
        if isinstance(in_channels, int):
            in_src, in_dst = in_channels, in_channels
        else:
            in_src, in_dst = in_channels

        self.bias = Parameter(torch.empty(out_channels))

        if nn is None:
            # Default edge scorer: [x_i || x_j] -> one weight per head.
            self.mlp = torch.nn.Sequential(
                torch.nn.Linear(in_dst + in_src, out_channels, bias=True),
                torch.nn.ReLU(),
                torch.nn.Linear(out_channels, out_channels, bias=True),
                torch.nn.ReLU(),
                torch.nn.Linear(out_channels, heads, bias=True)
            )
        else:
            self.mlp = nn

        # Projects source node features into `heads` slots of width out_channels.
        self.lin = torch.nn.Linear(in_src, out_channels * heads, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize all learnable parameters (projection, edge MLP, bias)."""
        super().reset_parameters()
        self.lin.reset_parameters()
        # The edge-weight MLP must be reset too; otherwise a "reset" layer
        # keeps the previously trained edge scorer.
        self._reset_module(self.mlp)
        zeros(self.bias)

    @staticmethod
    def _reset_module(module) -> None:
        """Recursively call ``reset_parameters`` on ``module`` or its children."""
        if hasattr(module, 'reset_parameters'):
            module.reset_parameters()
        elif isinstance(module, torch.nn.Module):
            # Containers such as nn.Sequential have no reset_parameters of
            # their own, so descend into their children.
            for child in module.children():
                MC_GIN._reset_module(child)

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                alpha: OptTensor = None) -> Tensor:
        """Run one round of weighted message passing.

        Args:
            x: Node features, or an ``(x_src, x_dst)`` pair for bipartite graphs.
            edge_index: Graph connectivity accepted by ``propagate``.
            alpha: Optional precomputed edge weights that replace the MLP output
                inside :meth:`message`. These are presumably
                ``[num_edges, heads]``, in the same edge order PyG uses for
                messages (confirm). Softmax is still applied to them when
                ``use_softmax`` is set.

        Returns:
            Tensor of shape ``[num_target_nodes, out_channels]`` (heads summed).
        """
        H, C = self.heads, self.out_channels

        if isinstance(x, Tensor):
            x = (x, x)
        x_src, x_dst = x

        x_prime = self.lin(x_src).view(-1, H, C)

        # NOTE(review): the comment below does not list the arguments actually
        # passed (`x` is `(x_prime, None)`, plus `a` and `alpha`). It is left
        # unchanged because PyG may parse `propagate_type` comments; confirm
        # before editing.
        # propagate_type: (x: PairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=(x_prime, None), a=(x_src, x_dst),
                             edge_attr=None, size=None, alpha=alpha)

        # [num_nodes, heads, out_channels] -> [num_nodes, out_channels]
        out = out.sum(dim=1)

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(self, x_j: Tensor, a_j: Tensor, a_i: Tensor, edge_attr: OptTensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int], alpha: OptTensor = None) -> Tensor:
        """Scale each projected source feature by its per-head edge weight.

        ``x_j`` is ``[num_edges, heads, out_channels]``. ``a_i`` / ``a_j`` are
        the raw target / source features used to score the edge. ``edge_attr``
        is accepted but unused.
        """
        if alpha is None:
            alpha = torch.cat((a_i, a_j), dim=-1)
            alpha = self.mlp(alpha)
        # Exposes the raw (pre-softmax) weights for inspection. This keeps a
        # reference to the autograd graph until the next call overwrites it.
        self.alpha = alpha

        if self.use_softmax:
            alpha = softmax(alpha, index, ptr, size_i)
        # Add a trailing dim so each head's weight broadcasts over out_channels.
        return x_j * alpha.unsqueeze(-1)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')
