# postponed evaluation of annotations (helps avoid circular import)
from __future__ import annotations
from typing import TYPE_CHECKING

import torch
from torch import Tensor
from torch.nn import Module
from torch_geometric.nn.models import MLP
from torch_geometric.nn.models.basic_gnn import BasicGNN
from torch_geometric.data import Data
from typing_extensions import override

if TYPE_CHECKING:  # avoid circular import
    from ._config import SimpleGNNConfig


class SimpleGNN(Module):
    """
    Edge-level prediction model that works in two steps.

    First, a GNN computes an embedding for every node. The node features stored in the graph are used
    as input; only if the graph has none (``graph.x is None``) are random values drawn instead.
    Then, for each edge, the embeddings of its two endpoints and the edge attribute are concatenated and
    passed through an MLP (here called edge classifier) that calculates the final edge-level output.
    """

    gnn: BasicGNN
    edge_classifier: MLP
    in_channels: int

    @override
    def __init__(self, config: SimpleGNNConfig):
        """
        Build the GNN and the edge classifier from ``config``.

        The following members of ``config`` are read: ``in_channels``, ``gnn_kwargs``,
        ``get_gnn_architecture()``, ``gnn_hidden_channels``, ``gnn_out_channels``, ``gnn_layers``,
        ``edge_classifier_hidden_channels`` and ``edge_classifier_layers``.
        """
        super().__init__()

        # Kept on the module because forward() needs it to size randomly generated node features.
        self.in_channels = config.in_channels

        gnn_kwargs = {} if config.gnn_kwargs is None else config.gnn_kwargs
        gnn_architecture = config.get_gnn_architecture()
        self.gnn = gnn_architecture(
            in_channels=config.in_channels,
            hidden_channels=config.gnn_hidden_channels,
            out_channels=config.gnn_out_channels,
            num_layers=config.gnn_layers,
            **gnn_kwargs,
        )

        self.edge_classifier = MLP(
            # Two node embeddings (one per endpoint) plus one scalar edge attribute.
            in_channels=config.gnn_out_channels * 2 + 1,
            hidden_channels=config.edge_classifier_hidden_channels,
            out_channels=1,
            num_layers=config.edge_classifier_layers,
            # NOTE(review): the original intent here was "apply an activation before the first layer,
            # because the last GNN layer does not have one". According to the torch_geometric MLP
            # documentation, `act_first` instead selects activation *before normalization* inside each
            # layer - TODO confirm this flag does what was intended.
            act_first=True,
        )

    @override
    def forward(self, graph: Data) -> Tensor:
        """
        Compute one prediction per edge of ``graph``.

        ``graph.edge_index`` is indexed as two rows (first and second endpoint of every edge).
        ``graph.edge_attr`` is required and is expected to hold one scalar per edge (it is unsqueezed
        to a single feature column). If ``graph.x`` is ``None``, random node features of width
        ``in_channels`` are drawn on the device of ``graph.edge_index``.

        Returns a 1-D tensor with one entry per edge, also for graphs with one or zero edges.
        """
        if graph.x is not None:
            x = graph.x
        else:
            x = torch.randn([graph.num_nodes, self.in_channels], device=graph.edge_index.device)

        edge_index = graph.edge_index

        out_gnn = self.gnn(x, edge_index, edge_attr=graph.edge_attr)

        # Gather the embeddings of both endpoints for all edges at once instead of looping over the
        # edges in Python; this also yields a valid (empty) result for graphs without edges.
        first_nodes, second_nodes = edge_index[0], edge_index[1]
        edge_features = torch.cat(
            (out_gnn[first_nodes], out_gnn[second_nodes], graph.edge_attr.unsqueeze(1)),
            dim=1,
        )

        # Only drop the trailing output dimension: a bare squeeze() would also remove the edge
        # dimension for single-edge graphs and return a 0-dim tensor.
        edge_predictions = self.edge_classifier(edge_features).squeeze(-1)

        return edge_predictions
