# postponed evaluation of annotations (helps avoid circular import)
from __future__ import annotations
from typing import TYPE_CHECKING

import torch
from torch import Tensor
from torch.nn import Module
from torch_geometric.nn.models import MLP
from torch_geometric.nn.models.basic_gnn import BasicGNN
from torch_geometric.data import Data
from torch_geometric.transforms import LineGraph
from typing_extensions import override

if TYPE_CHECKING:  # avoid circular import
    from ._config import LineGraphGNNConfig


class LineGraphGNN(Module):
    """
    Transforms a given graph into its line graph, then runs a GNN on that line graph.
    The resulting node features of the line graph (= edge features of the original graph) are fed into an MLP (here
    called edge classifier) that produces the final edge-level output.

    Before the edge classifier, the raw edge features are concatenated onto the GNN output (a skip connection). The
    classifier therefore sees both the learned line-graph embedding and the original edge attributes.
    """

    gnn: BasicGNN
    edge_classifier: MLP
    in_channels: int
    graph_transformation: LineGraph

    @override
    def __init__(self, config: LineGraphGNNConfig):
        """
        Build the line-graph transform, the GNN and the edge classifier from ``config``.

        :param config: model configuration. The fields read here are ``in_channels``, ``gnn_hidden_channels``,
            ``gnn_out_channels``, ``gnn_layers``, ``gnn_kwargs`` (may be ``None``),
            ``edge_classifier_hidden_channels``, ``edge_classifier_layers`` and ``get_gnn_architecture()``. The last
            one must return a GNN class accepting the keyword arguments used below.
        """
        super().__init__()

        self.in_channels = config.in_channels
        # TODO force_directed=True might not be the best way to do this
        # force_directed=True always makes LineGraph treat the graph as directed. Per the PyG docs, the line graph's
        # node features are then the original graph's edge attributes, which is what forward() consumes.
        self.graph_transformation = LineGraph(force_directed=True)

        gnn_kwargs = dict() if config.gnn_kwargs is None else config.gnn_kwargs
        self.gnn = config.get_gnn_architecture()(
            in_channels=config.in_channels,
            hidden_channels=config.gnn_hidden_channels,
            out_channels=config.gnn_out_channels,
            num_layers=config.gnn_layers,
            **gnn_kwargs,
        )

        # forward() concatenates the raw edge features onto the GNN output. The classifier input width is therefore
        # gnn_out_channels + <number of edge-feature channels>.
        # A non-positive in_channels (e.g. -1 for lazy initialisation) does not give that number. It falls back to
        # the single channel that forward() produces for 1-D (scalar) edge attributes.
        skip_channels = config.in_channels if config.in_channels > 0 else 1
        self.edge_classifier = MLP(
            in_channels=config.gnn_out_channels + skip_channels,
            hidden_channels=config.edge_classifier_hidden_channels,
            out_channels=1,
            num_layers=config.edge_classifier_layers,
            # NOTE(review): in PyG's MLP, act_first=True only applies the activation *before* the normalisation inside
            # each hidden layer. It does NOT add an activation in front of the first linear layer. The GNN output
            # (whose last layer has no activation) therefore goes into the first linear layer as-is. Confirm this is
            # intended.
            act_first=True,
        )

    @override
    def forward(self, graph: Data) -> Tensor:
        """
        Predict one value per edge of ``graph``.

        :param graph: input graph. Its ``edge_attr`` becomes the line graph's node features. A 1-D ``edge_attr`` is
            treated as a single feature channel; a 2-D one is used as ``(num_edges, channels)``.
        :return: 1-D tensor with one entry per line-graph node (i.e. per edge kept by the transform). It stays 1-D
            even for a graph with a single edge.
            NOTE(review): PyG's LineGraph coalesces ``edge_index`` internally. The output order should therefore match
            ``graph.edge_index`` only if the input is already coalesced (sorted, no duplicates). Confirm with callers.
        :raises ValueError: if the line graph has no node features (e.g. ``graph.edge_attr`` is ``None``).
        """
        # LineGraph rewrites the Data object it is given, so transform a copy and leave the caller's graph intact.
        line_graph = self.graph_transformation(graph.clone())
        x = line_graph.x
        if x is None:
            raise ValueError("LineGraphGNN requires edge features: graph.edge_attr must not be None")
        if x.dim() == 1:
            # scalar edge features -> (num_edges, 1), i.e. a single feature channel
            x = x.unsqueeze(-1)

        edge_features = self.gnn(x, line_graph.edge_index)
        # skip connection: give the classifier the raw edge features next to the learned embedding
        edge_features = torch.cat((edge_features, x), dim=1)
        # Drop only the trailing out_channels=1 dimension. A bare squeeze() would turn a single-edge graph into a
        # 0-d tensor.
        edge_predictions = self.edge_classifier(edge_features).squeeze(-1)

        return edge_predictions
