from typing import Union, Callable, Any

import torch
import torch.nn as nn
from torch_geometric.data.hetero_data import HeteroData

from hmpn.abstract.abstract_modules import AbstractMetaModule
from hmpn.common.latent_mlp import LatentMLP


class Heterogeneous2MetaModule(AbstractMetaModule):
    """
    Base class for the heterogeneous modules used in the GNN.
    They are used for updating node-, edge-, and global features.

    The constructor installs ``self._maybe_global``: a callable that either forwards to
    ``_concat_global`` (to be provided by subclasses) or returns its first argument untouched.
    """

    def __init__(
        self,
        *,
        scatter_reducers: Union[Callable, list[Callable]],
        use_global_features: bool,
    ):
        """
        Args:
            scatter_reducers: How to aggregate over the nodes/edges/globals. Can for example be [torch.scatter_mean]
            use_global_features: Whether global features are used. If True, ``self._maybe_global`` dispatches to
                ``self._concat_global``; otherwise it hands back the features it was given.
        """
        super().__init__(scatter_reducers=scatter_reducers)
        # Both branches store a bound method. A lambda stored on the instance cannot be pickled,
        # which would get in the way of pickling the module as a whole.
        if use_global_features:
            self._maybe_global = self._concat_global
        else:
            self._maybe_global = self._skip_global

    def _skip_global(self, features, *args, **kwargs):
        """
        No-op counterpart of ``_concat_global``.

        Args:
            features: The feature tensor that would otherwise be extended with global features
            *args: Ignored; accepted so the call signature matches ``_concat_global``
            **kwargs: Ignored

        Returns: ``features``, unchanged
        """
        return features

    def _concat_global(self, *args, **kwargs):
        """
        Hook for appending global features to a feature tensor. Must be overridden by every subclass
        that is constructed with ``use_global_features=True``.

        Raises:
            NotImplementedError: Always, in this base implementation
        """
        raise NotImplementedError(
            f"Module {type(self)} needs to implement _concat_global(self,*args, **kwargs)"
        )


class Heterogeneous2EdgeModule(Heterogeneous2MetaModule):
    """
    Edge update step: holds one LatentMLP per edge type and overwrites the ``edge_attr`` of every
    edge type of a graph in place.
    """

    def __init__(
        self,
        *,
        in_dims: dict[str, int],
        out_dim: int,
        mlp_config: dict[str, Any],
        scatter_reducers: Union[Callable, list[Callable]],
        use_global_features: bool,
    ):
        """
        Args:
            in_dims: Input feature size of the MLP of each edge type. Keys are converted with ``str()`` before
                being used as ModuleDict keys.
            out_dim: Handed to every LatentMLP as its ``latent_dimension``
            mlp_config: Configuration handed to every LatentMLP
            scatter_reducers: Forwarded to the parent class
            use_global_features: Whether global features are appended to the MLP input
        """
        super().__init__(
            scatter_reducers=scatter_reducers,
            use_global_features=use_global_features,
        )
        per_edge_type_mlps = {
            str(name): LatentMLP(
                in_features=num_feat,
                latent_dimension=out_dim,
                config=mlp_config,
            )
            for name, num_feat in in_dims.items()
        }
        self.mlps = nn.ModuleDict(per_edge_type_mlps)

    def forward(self, graph: HeteroData):
        """
        Update the edge attributes of every edge type of a heterogeneous graph.
        Args:
            graph: HeteroData object of pytorch geometric. Represents a (batch of) heterogeneous graph(s)
        Returns: None. The new edge attributes are written back to the edge stores of ``graph``.
        """
        for current_type in graph.edge_types:
            self.process_edge(current_type, graph)

    def process_edge(self, edge_type: str, graph: HeteroData):
        """
        Recompute ``edge_attr`` for a single edge type and store it on the graph.
        Args:
            edge_type: Key of the edge type to update.
                NOTE(review): annotated as ``str`` but unpacked below into three parts
                (source type, relation, destination type) - confirm and adjust the annotation.
            graph: Graph whose edge store for ``edge_type`` is updated in place
        """
        source_type, _, target_type = edge_type
        store = graph.get_edge_store(*edge_type)

        # unpacking yields the first and second row of edge_index
        source_idx, target_idx = store.edge_index

        # gather the features of both endpoint nodes, one row per edge
        sender_feat = torch.index_select(graph[source_type].x, 0, source_idx)
        receiver_feat = torch.index_select(graph[target_type].x, 0, target_idx)

        mlp_input = torch.cat([sender_feat, receiver_feat, store.edge_attr], dim=1)
        # arguments stay positional: the no-global variant only looks at the first positional argument
        mlp_input = self._maybe_global(mlp_input, graph, source_type, source_idx)

        edge_mlp = self.mlps[str(edge_type)]
        store.edge_attr = edge_mlp(mlp_input)

    def _concat_global(self, features, graph, src_node_type, src_indices):
        """
        Append to each edge's features the row of ``graph.u`` selected via the ``batch`` entry of the
        edge's source node.
        Returns: ``features`` concatenated with the selected global features along dimension 1
        """
        node_to_graph = graph[src_node_type].batch
        per_edge_global = graph.u[node_to_graph[src_indices]]
        return torch.cat([features, per_edge_global], dim=1)


class Heterogeneous2NodeModule(Heterogeneous2MetaModule):
    """
    Node update step: holds one LatentMLP per node type and overwrites the features ``x`` of every
    node type of a graph in place, using the node's own features, the aggregated attributes of the
    edge types that end in this node type and, optionally, the global features.
    """

    def __init__(
        self,
        *,
        in_dims: dict[str, int],
        out_dim: int,
        mlp_config,
        scatter_reducers: list[Callable],
        use_global_features: bool,
        flip_edges_for_nodes: bool,
    ):
        """
        Args:
            in_dims: Input feature size of the MLP of each node type, keyed by node type name
            out_dim: Handed to every LatentMLP as its ``latent_dimension``
            mlp_config: Configuration handed to every LatentMLP
            scatter_reducers: How to aggregate over the nodes/edges/globals. Can for example be [torch.scatter_mean]
            use_global_features: Whether global features are appended to the MLP input
            flip_edges_for_nodes: If True, edge attributes are aggregated by the first row of ``edge_index``
                (source indices); otherwise by the second row (destination indices)
        """
        super().__init__(
            scatter_reducers=scatter_reducers,
            use_global_features=use_global_features,
        )
        self.mlps = nn.ModuleDict(
            {
                node_name: LatentMLP(
                    in_features=num_feat,
                    latent_dimension=out_dim,
                    config=mlp_config,
                )
                for node_name, num_feat in in_dims.items()
            }
        )

        # use the source indices for feat. aggregation if edges shall be flipped.
        # Stored as bound methods rather than lambdas: a lambda kept on the instance cannot be pickled.
        if flip_edges_for_nodes:
            self._get_edge_indices = self._take_source_indices
        else:
            self._get_edge_indices = self._take_dest_indices

    def _take_source_indices(self, indices):
        """Return the first row of an ``edge_index`` (the source node of every edge)."""
        return indices[0]

    def _take_dest_indices(self, indices):
        """Return the second row of an ``edge_index`` (the destination node of every edge)."""
        return indices[1]

    def forward(self, graph: HeteroData) -> None:
        """
        Compute updates for each node feature vector as x_i' = f2(x_i, agg_j f1(e_ij, x_j), u),
        where f1 and f2 are MLPs
        Args:
            graph: HeteroData object of pytorch geometric. Represents a (batch of) heterogeneous graph(s)
        Returns: None. The new node features are written back to the node stores of ``graph``.
        """
        for node_type in graph.node_types:
            self.process_node(node_type, graph)

    def process_node(self, node_type: str, graph: HeteroData) -> None:
        """
        Recompute the features ``x`` of a single node type and store them on the graph.
        Args:
            node_type: Name of the node type to update
            graph: Graph whose node store for ``node_type`` is updated in place
        """
        node_store = graph[node_type]

        edge_feat_list = self.get_edge_feat(node_type, graph)
        # MLP input: own features followed by one aggregated block per relevant edge type
        aggr_feat = torch.cat([node_store.x] + edge_feat_list, 1)
        aggr_feat = self._maybe_global(aggr_feat, graph, node_type)
        node_store.x = self.mlps[node_type](aggr_feat)

    def get_edge_feat(self, node_type: str, graph: HeteroData) -> list:
        """
        Aggregate edge attributes onto the nodes of ``node_type``, separately for each relevant edge type.
        Args:
            node_type: Name of the node type to aggregate for
            graph: Graph providing the edge stores and the node count of ``node_type``
        Returns: A list with one ``self.multiscatter`` result per relevant edge type, in the order in which
            ``get_relevant_edge_stores`` yields them
        """
        num_nodes = graph[node_type].x.shape[0]
        edge_feat_list = []
        for edge_store in self.get_relevant_edge_stores(node_type, graph):
            scatter_edge_indices = self._get_edge_indices(edge_store.edge_index)
            edge_feat = self.multiscatter(
                features=edge_store.edge_attr,
                indices=scatter_edge_indices,
                dim=0,
                dim_size=num_nodes,
            )
            # edge_feat presumably has shape (num_nodes, latent_dimension * n_scatter_reducers);
            # multiscatter is defined in the parent class - TODO confirm
            edge_feat_list.append(edge_feat)
        return edge_feat_list

    def get_relevant_edge_stores(self, node_type: str, graph: HeteroData):
        """
        Yield the edge store of every edge type whose last component equals ``node_type``.

        NOTE(review): this filter is on the destination type regardless of ``flip_edges_for_nodes``. With
        flipping enabled, ``get_edge_feat`` scatters by *source* indices into a tensor sized by the number of
        nodes of ``node_type``; for edge types whose source and destination types differ this looks
        inconsistent - confirm against the code that computes ``in_dims``.
        """
        for edge_type, edge_store in graph.edge_items():
            # look for all edges that have the current node type as destination
            if edge_type[-1] == node_type:
                yield edge_store

    def _concat_global(self, features, graph, node_type):
        """
        Append to each node's features the row of ``graph.u`` selected by the node's ``batch`` entry.
        Returns: ``features`` concatenated with the selected global features along dimension 1
        """
        batch = graph[node_type].batch
        return torch.cat([features, graph.u[batch]], 1)


class Heterogeneous2GlobalModule(Heterogeneous2MetaModule):
    """
    Global update step: reduces all node features and all edge attributes to one row per entry of
    ``graph.u`` and feeds them, together with ``graph.u`` itself, through a single LatentMLP.
    """

    def __init__(
        self,
        *,
        in_dim: int,
        out_dim: int,
        mlp_config: dict[str, Any],
        scatter_reducers: Union[Callable, list[Callable]],
        use_global_features: bool,
    ):
        """
        Args:
            in_dim: Input feature size of the global MLP
            out_dim: Handed to the LatentMLP as its ``latent_dimension``
            mlp_config: Configuration handed to the LatentMLP
            scatter_reducers: Forwarded to the parent class
            use_global_features: Forwarded to the parent class. Note that ``forward`` of this module reads
                ``graph.u`` unconditionally and does not go through ``_maybe_global``.
        """
        super().__init__(
            scatter_reducers=scatter_reducers,
            use_global_features=use_global_features,
        )
        self.mlp = LatentMLP(
            in_features=in_dim,
            latent_dimension=out_dim,
            config=mlp_config,
        )

    def forward(self, graph: HeteroData):
        """
        Run the global update.
        Args:
            graph: The (batch of) heterogeneous graph(s) to update; must provide global features ``u``

        Returns: None, ``graph.u`` is overwritten in place

        """
        node_summary = self.get_node_feat(graph)
        edge_summary = self.get_edge_feat(graph)
        mlp_input = torch.cat([node_summary, edge_summary, graph.u], dim=1)
        # write back to graph
        graph.u = self.mlp(mlp_input)

    def get_edge_feat(self, graph: HeteroData):
        """
        Reduce the attributes of every edge type to one row per entry of ``graph.u``.
        Returns: The per-edge-type reductions concatenated along dimension 1
        """
        reduced_per_edge_type = []
        for edge_type, store in graph.edge_items():
            source_type, _, _ = edge_type
            source_idx = store.edge_index[0]
            # `batch` assigns each node to a graph in the batch of graphs; looking it up for the source node
            # of each edge tells us which graph the edge belongs to
            edge_to_graph = graph[source_type].batch[source_idx]
            reduced_per_edge_type.append(
                self.multiscatter(
                    features=store.edge_attr,
                    indices=edge_to_graph,
                    dim=0,
                    dim_size=graph.u.shape[0],
                )
            )
        return torch.cat(reduced_per_edge_type, dim=1)

    def get_node_feat(self, graph: HeteroData):
        """
        Reduce the features of every node type to one row per entry of ``graph.u``.
        Returns: The per-node-type reductions concatenated along dimension 1
        """
        reduced_per_node_type = [
            self.multiscatter(
                features=store.x,
                indices=graph[name].batch,
                dim=0,
                dim_size=graph.u.shape[0],
            )
            for name, store in graph.node_items()
        ]
        return torch.cat(reduced_per_node_type, dim=1)
