# postponed evaluation of annotations (helps avoid circular import)
from __future__ import annotations
from typing import TYPE_CHECKING

import torch
from torch import Tensor
from torch.nn import Module
from torch_geometric.data import Batch
from torch_geometric.utils import to_dense_adj, to_dense_batch
from typing_extensions import override

if TYPE_CHECKING:  # avoid circular import
    from ._config import ResidualGatedGCNConfig
from ._residual_gated_gcn import ResidualGatedGCNModel


class _Dictionary(dict):
    """
    A helper class that's necessary because the original code accesses some config items with `config.key` and some
    with `config["key"]`.

    Every key of the dictionary can also be read as an attribute.
    """

    def __getattr__(self, key):
        """
        Returns the item stored under `key`.

        This is only called when normal attribute lookup fails, so methods of `dict` are not shadowed by keys.

        Raises `AttributeError` if there is no such key.
        """
        try:
            return self[key]
        except KeyError as error:
            # the attribute protocol requires AttributeError; letting the KeyError escape would break `hasattr()`,
            # `getattr()` with a default and `copy.deepcopy()`, which all probe for optional attributes
            raise AttributeError(key) from error


class ResidualGatedGCN(Module):
    """
    Wraps `ResidualGatedGCNModel` so that it can be called with a pytorch geometric `Batch`.

    The batch is converted to dense adjacency matrices before the wrapped model is called, and the dense edge
    predictions it returns are flattened back into a 1-D tensor by `dense_to_sparse()`.
    """

    model: ResidualGatedGCNModel

    @override
    def __init__(self, config: ResidualGatedGCNConfig):
        """
        Builds the wrapped model from `config`.

        Parameters:
        - `config`: Provides `num_nodes`, `node_channels`, `voc_edges_in`, `hidden_channels`, `gnn_layers`,
          `mlp_layers` and `gnn_aggregation`.
        """
        super().__init__()

        # translate our config names into the names that the wrapped model looks up
        model_settings = _Dictionary(
            {
                "num_nodes": config.num_nodes,
                "node_dim": config.node_channels,
                # input node signal vocabulary size (ignored, so we can set it to None)
                "voc_nodes_in": None,
                "voc_edges_in": config.voc_edges_in,
                # we only need one output score for each edge
                "voc_edges_out": 1,
                "hidden_dim": config.hidden_channels,
                "num_layers": config.gnn_layers,
                "mlp_layers": config.mlp_layers,
                "aggregation": config.gnn_aggregation,
            }
        )

        # dtypeFloat and dtypeLong are ignored, so we can just pass None
        self.model = ResidualGatedGCNModel(model_settings, dtypeFloat=None, dtypeLong=None)

    @override
    def forward(self, graphs: Batch) -> Tensor:
        """
        Runs the wrapped model on a batch of graphs.

        Parameters:
        - `graphs`: Batch whose `x`, `edge_index`, `edge_attr` and `batch` attributes are read.

        Returns the edge predictions as flattened by `dense_to_sparse()`.
        """
        edge_index = graphs.edge_index
        batch_vector = graphs.batch

        # the model expects dense per-graph matrices instead of pytorch geometric's flat edge list
        dense_edges = to_dense_adj(edge_index, batch_vector).int()
        dense_edge_values = to_dense_adj(edge_index, batch_vector, graphs.edge_attr)
        dense_coordinates, _ = to_dense_batch(graphs.x, batch_vector)

        scores = self.model(
            x_edges=dense_edges,
            x_edges_values=dense_edge_values,
            x_nodes=None,  # unused
            x_nodes_coord=dense_coordinates,
        )

        # drop the trailing dimension in place (presumably the single output score per edge - TODO confirm)
        scores.squeeze_(3)

        # back from adjacency matrix form to one value per edge
        return dense_to_sparse(scores)


def dense_to_sparse(adjacency_matrix: Tensor) -> Tensor:
    """
    Converts a batch of dense adjacency matrices to a flat tensor of edge attributes.

    Adapted from `torch_geometric.utils.dense_to_sparse()`.
    The problem with the original implementation is that it drops edges if their entry in the adjacency matrix is 0.
    This implementation keeps every off-diagonal entry, even if its value is 0. Only the diagonal entries
    (self loops) are removed. No edge indices are returned.

    The values are ordered graph by graph, and within each graph row by row. Callers that match them against an
    `edge_index` need that `edge_index` to list all non-self-loop node pairs in this same order.

    Parameters:
    - `adjacency_matrix`: Adjacency matrix. Size `[batch_size, num_nodes, num_nodes]`. It does not need to be
      contiguous.

    Returns `edge_attr` of size `[batch_size * num_nodes * (num_nodes - 1)]`.

    Raises `ValueError` if `adjacency_matrix` is not a batch of square matrices.
    """
    if adjacency_matrix.dim() != 3 or adjacency_matrix.size(1) != adjacency_matrix.size(2):
        raise ValueError(
            f"expected adjacency matrix of size [batch_size, num_nodes, num_nodes], "
            f"got {list(adjacency_matrix.shape)}"
        )

    num_nodes = adjacency_matrix.size(-1)

    # True everywhere except on the diagonal, i.e. for every edge that is not a self loop
    off_diagonal = ~torch.eye(num_nodes, dtype=torch.bool, device=adjacency_matrix.device)

    # masking the last two dimensions keeps the row-major order inside each graph; flattening afterwards
    # concatenates the graphs in batch order
    edge_attr = adjacency_matrix[:, off_diagonal].flatten()

    return edge_attr
