import torch
from torch import nn

from typing import Union, Tuple
from torch import Tensor
import torch_geometric.nn as pygnn
from torch_scatter import scatter_sum, scatter_mean
from torch_geometric.data import Batch as PyGBatch, Data as PyGData

from .builder import AFFINITY_HEAD


@AFFINITY_HEAD.register()
class GraphEncoder(nn.Module):
    """GINE-based graph encoder producing per-node and per-graph embeddings.

    Integer node and edge features are embedded into ``hidden_dim`` channels,
    passed through a stack of ``GINEConv`` layers, combined with the initial
    node embedding through a single residual connection, and finally turned
    into:

    * node-level features via a ``Linear -> LayerNorm`` head, and
    * graph-level features via a sum/mean pooling followed by the product of
      two linear projections.

    The class is registered in the ``AFFINITY_HEAD`` registry.
    """

    def __init__(
        self,
        input_node_dim: int,
        input_edge_dim: int,
        hidden_dim: int,
        out_dim: int,
        num_convs: int,
    ):
        """Build the embedding, convolution, readout and head layers.

        Args:
            input_node_dim: Number of rows of the node embedding table.
            input_edge_dim: Number of rows of the edge embedding table.
            hidden_dim: Channel width used by the embeddings and convolutions.
            out_dim: Channel width of both returned feature tensors.
            num_convs: Number of stacked ``GINEConv`` layers.
        """
        super().__init__()
        # Output widths of the graph-level and node-level features; both equal
        # ``out_dim``. Presumably read by downstream modules -- verify callers.
        self.graph_channels: int = out_dim
        self.atom_channels: int = out_dim

        # Ligand Encoding
        # NOTE: submodules are created in a fixed order so that state_dict
        # keys and RNG consumption during construction stay stable.
        self.node_layer = nn.Embedding(input_node_dim, hidden_dim)
        self.edge_layer = nn.Embedding(input_edge_dim, hidden_dim)
        self.conv_list = nn.ModuleList([
            pygnn.GINEConv(
                nn=nn.Sequential(pygnn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU()),
                edge_dim=hidden_dim,
            )
            for _ in range(num_convs)
        ])
        # Readout consumes the concatenation [sum-pool, mean-pool], hence
        # ``hidden_dim * 2`` input channels.
        self.readout_layer = nn.Linear(hidden_dim * 2, out_dim)
        self.readout_gate = nn.Linear(hidden_dim * 2, out_dim)

        # Per-node output head.
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, out_dim),
            nn.LayerNorm(out_dim)
        )

    def initialize_weights(self) -> None:
        """Re-initialise ``nn.Linear`` and ``nn.Embedding`` submodules in place.

        ``nn.Linear`` weights are drawn from N(0, 0.01^2) with zero bias and
        ``nn.Embedding`` weights from U(-1, 1). This method is not invoked by
        ``__init__``; it has to be called explicitly.
        """
        def _init_weight(m: nn.Module) -> None:
            # NOTE(review): only ``torch.nn.Linear`` instances are matched; if
            # ``pygnn.Linear`` is not a subclass of it, the linears inside the
            # GINEConv blocks keep their default init -- confirm this is wanted.
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Embedding):
                # ``nn.init`` runs under ``no_grad``; avoids touching ``.data``.
                nn.init.uniform_(m.weight, -1, 1)
        self.apply(_init_weight)

    def forward(
        self,
        data: Union[PyGData, PyGBatch],
    ) -> Tuple[Tensor, Tensor]:
        """Encode a graph (or a batch of graphs).

        Args:
            data: A single PyG ``Data`` or a ``Batch``. The attributes ``x``,
                ``edge_attr`` and ``edge_index`` are read; for a ``Batch``,
                ``batch`` and ``num_graphs`` are read as well. ``x`` and
                ``edge_attr`` are fed to ``nn.Embedding`` and therefore must
                hold integer indices (assumed 1-D, one index per node/edge --
                TODO confirm against the data pipeline).

        Returns:
            A tuple ``(X, Z)`` where ``X`` holds the node-level features with
            ``out_dim`` channels (one row per node) and ``Z`` holds the
            graph-level features with ``out_dim`` channels (one row per graph;
            a single row when ``data`` is not a ``Batch``).
        """
        x: Tensor = self.node_layer(data.x)                 # V -> V, Fh
        edge_attr: Tensor = self.edge_layer(data.edge_attr)
        edge_index = data.edge_index

        # One residual connection around the whole convolution stack (not one
        # per layer): the initial embedding is added back after the last conv.
        skip_x = x
        for layer in self.conv_list:
            x = layer(x, edge_index, edge_attr)
        x = skip_x + x

        X = self.head(x)                                    # V, Fh -> V, Fout

        if isinstance(data, PyGBatch):
            # ``dim_size`` pins the number of output rows to the number of
            # graphs in the batch.
            Z1 = scatter_sum(x, data.batch, dim=0, dim_size=data.num_graphs)         # V, Fh -> N, Fh
            Z2 = scatter_mean(x, data.batch, dim=0, dim_size=data.num_graphs)        # V, Fh -> N, Fh
        else:
            Z1 = x.sum(0, keepdim=True)                     # V, Fh -> 1, Fh
            Z2 = x.mean(0, keepdim=True)                    # V, Fh -> 1, Fh
        Z = torch.cat([Z1, Z2], dim=-1)                     # N, 2*Fh
        # NOTE(review): the "gate" is a plain linear projection with no
        # activation, so this is a product of two linear maps -- confirm intent.
        Z = self.readout_gate(Z) * self.readout_layer(Z)    # N, 2*Fh -> N, Fout
        return X, Z
