from typing import List, Optional, Tuple

import torch
from jaxtyping import Float, Int
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.models.backbone.input_adapter.edge_features.fully_connected_index import (
    get_fully_connected_index,
    get_gnn_index,
)
from src.models.backbone.input_adapter.edge_features.proba_ancestor import ProbaAncestor
from src.models.backbone.input_adapter.node_features.upscaler import Upscaler
from src.utils.backbone_utils import make_nn
from src.utils.sparse_utils import index_subset_mask


class InputAdapter(nn.Module):
    """InputAdapter for PyG models.

    Turns a PyG batch into ``(x, edge_index, edge_attr)``:

    * node features are the concatenated outputs of the node feature callables,
      projected to ``d_hidden`` by an :class:`Upscaler`;
    * edge features are always computed on the fully connected edge index,
      batch-normalized and passed through an MLP built by ``make_nn``;
    * when ``fully_connected`` is ``False``, the GNN edge index is returned
      together with the subset of edge features selected for it.
    """

    def __init__(
        self,
        n: int,
        d_hidden: int,
        d_edge: int,
        fully_connected: bool,
        node_features: List,
        edge_features: Optional[List] = None,
    ) -> None:
        """Initialize the InputAdapter.

        :param n: Number of maximal nodes in one graph. Needed for padding.
        :param d_hidden: Dimension of hidden layer of PyG model; output width of the node upscaler.
        :param d_edge: Dimension of edge features; passed as the last argument of ``make_nn``
            for the edge MLP (presumably its output width — confirm against ``make_nn``).
        :param fully_connected: Whether to return the fully connected edge index (``True``) or
            the GNN edge index with the matching subset of edge features (``False``).
        :param node_features: List of node feature callables. Each must expose an integer
            attribute ``d`` (its output width) and be callable as ``feature(data)``.
        :param edge_features: List of edge feature callables. Each must expose an integer
            attribute ``d`` and be callable as ``feature(data, edge_index)``. Defaults to
            ``None``, which is treated as an empty list. Note that ``forward`` needs at least
            one edge feature.
        """
        super().__init__()

        # Build a fresh list per instance; a literal ``[]`` default would be shared
        # by every adapter constructed without this argument.
        if edge_features is None:
            edge_features = []

        self.n = n
        self.d_hidden = d_hidden
        self.fully_connected = fully_connected

        self.proba_ancestor = ProbaAncestor(n)

        # NOTE(review): the feature lists are kept as plain Python lists. If the features are
        # nn.Modules, their parameters/buffers are not registered on this module (not in
        # .parameters(), not moved by .to()). nn.ModuleList would fix that but changes
        # state_dict keys — confirm before switching.
        self.node_features = node_features
        # Total width of the concatenated node features (input width of the upscaler).
        self.d_node = sum(feature.d for feature in node_features)
        self.upscaler = Upscaler(d_input=self.d_node, d_output=d_hidden)

        self.edge_features = edge_features
        # NOTE: ``self.d_edge`` is the total width of the concatenated *raw* edge features.
        # It is NOT the ``d_edge`` constructor argument, which is handed to ``make_nn`` below.
        self.d_edge = sum(feature.d for feature in edge_features)
        self.edge_norm = nn.BatchNorm1d(self.d_edge)
        self.upscaler_edge = make_nn(3, self.d_edge, d_hidden, d_edge)

    def forward(
        self, data: Batch
    ) -> Tuple[
        Float[Tensor, "n_nodes d_hidden"],
        Int[Tensor, "2 n_edge"],
        Float[Tensor, "n_edge d_edge"],
    ]:
        """Compute node features, the edge index and edge features of a batch.

        Side effects: ``data`` is modified in place. Entries of ``data.x`` selected by
        ``data.parent_mask`` are set to 0, and ``data.p_anc`` is (re)assigned. Both happen
        before any feature callable runs, so the features see the modified batch.

        :param data: PyG batch object. It must provide ``x`` and ``parent_mask``, plus whatever
            attributes the configured feature callables read.
        :return: ``(x, edge_index, edge_attr)``:

            * ``x`` is the node features, projected to width ``d_hidden``;
            * ``edge_index`` is the fully connected index, or the GNN index if
              ``fully_connected`` is ``False``;
            * ``edge_attr`` is the MLP-projected edge features for that index.
        """
        data.x[data.parent_mask] = 0
        # Stored on the batch, presumably so feature callables can read it (TODO confirm).
        data.p_anc = self.proba_ancestor(data)

        x = self.upscaler(torch.cat([feature(data) for feature in self.node_features], dim=1))

        # Edge features are always computed and batch-normalized over the fully connected
        # edge index; in the sparse case the rows are subset afterwards.
        edge_index = get_fully_connected_index(data)
        # torch.cat raises on an empty list, so at least one edge feature is required here.
        raw_edge_attr = torch.cat(
            [feature(data, edge_index) for feature in self.edge_features], dim=1
        )
        edge_attr = self.upscaler_edge(self.edge_norm(raw_edge_attr))

        if not self.fully_connected:
            gnn_edge_index = get_gnn_index(data)
            # NOTE(review): the selected rows keep the ordering of the fully connected index.
            # They line up with gnn_edge_index's columns only if index_subset_mask and
            # get_gnn_index agree on ordering — confirm.
            gnn_index_mask = index_subset_mask(edge_index, gnn_edge_index)
            edge_index, edge_attr = gnn_edge_index, edge_attr[gnn_index_mask]
        return x, edge_index, edge_attr
