import math
import typing
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import (
    Adj,
    NoneType,
    OptTensor,
    PairTensor,
    SparseTensor,
)
from torch_geometric.utils import softmax

if typing.TYPE_CHECKING:
    from typing import overload
else:
    from torch.jit import _overload_method as overload


from jaxtyping import Float, Int
from torch import Tensor, nn


class Transformer(nn.Module):
    """Transformer implemented with PyG framework.

    A stack of :class:`TransformerLayer` modules applied one after another.
    Every layer receives the same ``edge_index`` and ``edge_attr``; only the
    node features are updated from layer to layer.
    """

    def __init__(
        self,
        n_layers: int,
        d_hidden: int,
        d_edge: int,
        n_head: int,
        drop_prob: float = 0.1,
    ) -> None:
        """Initialize transformer.

        :param n_layers: Number of layers.
        :param d_hidden: Hidden dimension.
        :param d_edge: Edge dimension.
        :param n_head: Number of attention heads.
        :param drop_prob: Dropout probability. Defaults to 0.1.
        """
        super().__init__()

        layers = [
            TransformerLayer(d_hidden, d_edge, n_head, drop_prob) for _ in range(n_layers)
        ]
        self.layers = nn.ModuleList(layers)

    def forward(
        self,
        x: Float[Tensor, "n_nodes d_hidden"],
        edge_index: Int[Tensor, "2 n_edges"],
        edge_attr: Float[Tensor, "n_edges d_edge"],
    ) -> Float[Tensor, "n_nodes d_hidden"]:
        """Run the node features through every transformer layer in order.

        :param x: Tensor containing node features.
        :param edge_index: Integer tensor containing edge information. It is
            passed unchanged to every layer.
        :param edge_attr: Tensor containing edge attributes. It is passed
            unchanged to every layer.
        :return: Updated node features. With ``n_layers == 0`` the input is
            returned as is.
        """
        for layer in self.layers:
            x = layer(x, edge_index, edge_attr)
        return x


class TransformerLayer(nn.Module):
    """Transformer layer used by the transformer.

    The layer applies, in this order: graph attention, dropout, residual add
    and layer norm; then a positionwise feedforward network, residual add,
    layer norm and a final dropout.
    """

    def __init__(
        self, d_hidden: int, d_edge: int, n_head: int, drop_prob: float = 0.1
    ) -> None:
        """Initialize the transformer layer.

        :param d_hidden: Number of hidden channels.
        :param d_edge: Number of edge channels.
        :param n_head: Number of attention heads.
        :param drop_prob: Dropout probability. Defaults to 0.1.
        :raises ValueError: If ``n_head`` is not positive or does not divide
            ``d_hidden`` evenly.
        """
        super().__init__()

        # Each head gets d_hidden // n_head channels and the heads are
        # concatenated. Unless the division is exact the attention output is
        # narrower than d_hidden and the residual add in forward() cannot
        # work, so reject such configurations up front.
        if n_head <= 0 or d_hidden % n_head != 0:
            raise ValueError(
                f"d_hidden ({d_hidden}) must be divisible by a positive "
                f"n_head (got {n_head})"
            )

        self.attention = TransformerConv(
            in_channels=d_hidden,
            out_channels=d_hidden // n_head,
            heads=n_head,
            beta=True,
            edge_dim=d_edge,
        )
        self.attention_dropout = nn.Dropout(p=drop_prob)
        self.norm_1 = nn.LayerNorm(d_hidden)
        self.ffn = PositionwiseFeedForward(d_hidden=d_hidden, hidden=d_hidden)
        self.norm_2 = nn.LayerNorm(d_hidden)
        self.layer_dropout = nn.Dropout(p=drop_prob)

    def forward(
        self,
        x: Float[Tensor, "n_nodes d_hidden"],
        edge_index: Int[Tensor, "2 n_edges"],
        edge_attr: Float[Tensor, "n_edges d_edge"],
    ) -> Float[Tensor, "n_nodes d_hidden"]:
        """Forward pass with Transformer Layer.

        :param x: Tensor containing node features.
        :param edge_index: Integer tensor containing edge information.
        :param edge_attr: Tensor containing edge attributes.
        :return: Updated node features.
        """
        # Attention sub-block: dropout is applied to the attention output
        # before the residual connection and normalisation.
        _x = x
        x = self.attention(x=x, edge_index=edge_index, edge_attr=edge_attr)
        x = self.attention_dropout(x)
        x = self.norm_1(x + _x)

        # Feedforward sub-block: here dropout is applied after the
        # normalisation, i.e. to the output of the layer.
        _x = x
        x = self.ffn(x)
        x = self.norm_2(x + _x)
        x = self.layer_dropout(x)
        return x


class PositionwiseFeedForward(nn.Module):
    """Two-layer feedforward network with a GELU in between.

    The same linear -> GELU -> linear mapping is applied to every row of the
    input, so each node is transformed independently of the others.
    """

    def __init__(self, d_hidden: int, hidden: int) -> None:
        """Create the two linear projections and the activation.

        :param d_hidden: Size of the input and of the output features.
        :param hidden: Size of the intermediate representation between the
            two linear layers.
        """
        super().__init__()

        # Expand (or keep) the width, apply the non-linearity, project back.
        self.linear1 = nn.Linear(in_features=d_hidden, out_features=hidden)
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(in_features=hidden, out_features=d_hidden)

    def forward(
        self, x: Float[Tensor, "n_nodes d_hidden"]
    ) -> Float[Tensor, "n_nodes d_hidden"]:
        """Apply the feedforward network to the node features.

        :param x: Tensor containing node features.
        :return: Updated node features.
        """
        activated = self.gelu(self.linear1(x))
        return self.linear2(activated)


class TransformerConv(MessagePassing):
    r"""Multi-head attention message passing layer with edge features.

    Queries are computed from the target nodes, keys and values from the
    source nodes. When ``edge_dim`` is given, edge features go through two
    independent bias-free projections: ``lin_edge`` is added to the keys and
    ``lin_edge_val`` is added to the values. Attention scores are scaled dot
    products normalised with a softmax, and messages are aggregated with
    ``aggr`` (``'add'`` unless overridden through ``kwargs``).

    Args:
        in_channels (int or tuple): Size of the input node features. A tuple
            gives the sizes of the source and target features separately.
        out_channels (int): Size of the output features per head.
        heads (int, optional): Number of attention heads. (default: :obj:`1`)
        concat (bool, optional): If :obj:`True`, the heads are concatenated,
            otherwise they are averaged. (default: :obj:`True`)
        beta (bool, optional): If :obj:`True`, the skip connection and the
            aggregated messages are blended with a learned sigmoid gate
            instead of being summed. Only effective together with
            ``root_weight``. (default: :obj:`False`)
        dropout (float, optional): Dropout probability applied to the
            normalised attention weights. (default: :obj:`0.`)
        edge_dim (int, optional): Size of the edge features, or :obj:`None`
            if the layer does not use edge features. (default: :obj:`None`)
        bias (bool, optional): Whether the skip projection has a bias.
            (default: :obj:`True`)
        root_weight (bool, optional): If :obj:`True`, a projection of the
            target node features is combined with the aggregated messages.
            (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`MessagePassing`.
    """

    # Attention weights of the most recent message() call; forward() reads
    # and clears this after propagate().
    _alpha: OptTensor

    def __init__(
        self,
        in_channels: Union[int, Tuple[int, int]],
        out_channels: int,
        heads: int = 1,
        concat: bool = True,
        beta: bool = False,
        dropout: float = 0.,
        edge_dim: Optional[int] = None,
        bias: bool = True,
        root_weight: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        # The gate blends the skip connection with the messages, so it is
        # meaningless without a skip connection.
        self.beta = beta and root_weight
        self.root_weight = root_weight
        self.concat = concat
        self.dropout = dropout
        self.edge_dim = edge_dim
        self._alpha = None

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_key = Linear(in_channels[0], heads * out_channels)
        self.lin_query = Linear(in_channels[1], heads * out_channels)
        self.lin_value = Linear(in_channels[0], heads * out_channels)
        if edge_dim is not None:
            # Separate edge projections for the attention keys and for the
            # values that are actually passed along the edges.
            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
            self.lin_edge_val = Linear(edge_dim, heads * out_channels, bias=False)
        else:
            self.lin_edge = self.register_parameter('lin_edge', None)
            # Register under its own name (previously 'lin_edge' was
            # registered a second time here).
            self.lin_edge_val = self.register_parameter('lin_edge_val', None)

        # With concatenated heads the output is heads * out_channels wide,
        # with averaged heads it is out_channels wide. The gate sees
        # [out, skip, out - skip], hence three times that width.
        skip_channels = heads * out_channels if concat else out_channels
        self.lin_skip = Linear(in_channels[1], skip_channels, bias=bias)
        if self.beta:
            self.lin_beta = Linear(3 * skip_channels, 1, bias=False)
        else:
            self.lin_beta = self.register_parameter('lin_beta', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every learnable sub-layer of the module.

        Both edge projections are reset when they exist. Note that resetting
        ``lin_edge_val`` as well consumes additional random numbers compared
        to earlier versions that skipped it, so initial weights for a fixed
        seed differ from those versions.
        """
        super().reset_parameters()
        self.lin_key.reset_parameters()
        self.lin_query.reset_parameters()
        self.lin_value.reset_parameters()
        if self.lin_edge is not None:
            self.lin_edge.reset_parameters()
        if self.lin_edge_val is not None:
            self.lin_edge_val.reset_parameters()
        self.lin_skip.reset_parameters()
        if self.lin_beta is not None:
            self.lin_beta.reset_parameters()

    @overload
    def forward(
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: NoneType = None,
    ) -> Tensor:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Tensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        pass

    @overload
    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: SparseTensor,
        edge_attr: OptTensor = None,
        return_attention_weights: bool = None,
    ) -> Tuple[Tensor, SparseTensor]:
        pass

    def forward(  # noqa: F811
        self,
        x: Union[Tensor, PairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        return_attention_weights: Optional[bool] = None,
    ) -> Union[
            Tensor,
            Tuple[Tensor, Tuple[Tensor, Tensor]],
            Tuple[Tensor, SparseTensor],
    ]:
        r"""Runs the forward pass of the module.

        Args:
            x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node
                features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            edge_attr (torch.Tensor, optional): The edge features. Required
                when the layer was created with :obj:`edge_dim`, and not
                accepted otherwise. (default: :obj:`None`)
            return_attention_weights (bool, optional): If set to any
                :obj:`bool` value (the check is on the type, so this
                includes :obj:`False`), will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)

        Raises:
            ValueError: If :obj:`edge_attr` is given although the layer was
                created without :obj:`edge_dim`.
        """
        if edge_attr is not None and self.lin_edge is None:
            # Without edge projections there is nothing that could map the
            # edge features to the message width; fail with a clear message
            # instead of an UnboundLocalError inside message().
            raise ValueError(
                "'edge_attr' was given, but the layer was created without "
                "'edge_dim' and therefore has no edge projections")

        H, C = self.heads, self.out_channels

        if isinstance(x, Tensor):
            x = (x, x)

        # x[0] holds the source features, x[1] the target features.
        query = self.lin_query(x[1]).view(-1, H, C)
        key = self.lin_key(x[0]).view(-1, H, C)
        value = self.lin_value(x[0]).view(-1, H, C)

        # propagate_type: (query: Tensor, key:Tensor, value: Tensor,
        #                  edge_attr: OptTensor)
        out = self.propagate(edge_index, query=query, key=key, value=value,
                             edge_attr=edge_attr)

        # message() stashed the attention weights on the module; take them
        # and clear the attribute so they are not kept alive between calls.
        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.root_weight:
            x_r = self.lin_skip(x[1])
            if self.lin_beta is not None:
                # Learned gate in (0, 1) that blends skip and messages.
                beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))
                beta = beta.sigmoid()
                out = beta * x_r + (1 - beta) * out
            else:
                out = out + x_r

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,
                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        """Compute the attention-weighted message for every edge.

        Args:
            query_i (torch.Tensor): Queries of the target node of each edge.
            key_j (torch.Tensor): Keys of the source node of each edge.
            value_j (torch.Tensor): Values of the source node of each edge.
            edge_attr (torch.Tensor, optional): The edge features.
            index (torch.Tensor): Passed to :obj:`softmax` as the grouping
                index of the normalisation.
            ptr (torch.Tensor, optional): Passed to :obj:`softmax`.
            size_i (int, optional): Passed to :obj:`softmax`.

        Returns:
            The values (plus projected edge features, if any) scaled by the
            attention weights, one slice per head.
        """
        edge_attr_val: OptTensor = None
        if self.lin_edge is not None:
            assert edge_attr is not None
            edge_attr_key = self.lin_edge(edge_attr).view(-1, self.heads,
                                                          self.out_channels)
            edge_attr_val = self.lin_edge_val(edge_attr).view(
                -1, self.heads, self.out_channels)
            key_j = key_j + edge_attr_key

        # Scaled dot-product attention, normalised by softmax.
        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)
        alpha = softmax(alpha, index, ptr, size_i)
        # Store the weights before dropout so that forward() can return them.
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        out = value_j
        # Guard on the projected tensor rather than on edge_attr so that the
        # name is never read unless it was actually computed above.
        if edge_attr_val is not None:
            out = out + edge_attr_val

        out = out * alpha.view(-1, self.heads, 1)
        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')
