from typing import Dict, Union, List, Callable

import torch
import torch.nn as nn
from torch_geometric.data.batch import Batch

from hmpn.abstract.abstract_modules import AbstractMetaModule


class HeterogeneousMetaModule(AbstractMetaModule):
    """
    Base class for the heterogeneous modules used in the GNN.
    They are used for updating node-, edge-, and global features.
    """

    def __init__(self, *,
                 mlps: nn.ModuleDict,
                 scatter_reducers: Union[Callable, List[Callable]],
                 use_global_features: bool,
                 latent_dimension: int):
        """
        Args:
            mlps: Dictionary of {{{node, edge}_type: List[MLP]} of MLPs to use for each respective node/edge.
                Implemented as nn.ModuleDict/nn.ModuleList to correctly register parameters
            scatter_reducers: How to aggregate over the nodes/edges/globals. Can for example be [torch.scatter_mean]
            use_global_features: Whether global features are used
        """
        super().__init__(scatter_reducers=scatter_reducers)
        self._mlps = mlps
        if use_global_features:
            self._maybe_global = self._concat_global
        else:
            self._maybe_global = lambda *args, **kwargs: args[0]

        self.latent_dimension = latent_dimension  # latent_dimension

    def _concat_global(self, *args, **kwargs):
        raise NotImplementedError(f"Module {type(self)} needs to implement _concat_global(self,*args, **kwargs)")


class HeterogeneousEdgeModule(HeterogeneousMetaModule):
    def forward(self, graph: Batch):
        """
        Compute edge updates for the edges of the Module for heterogeneous graphs
        Args:
            graph: HeteroData object of pytorch geometric. Represents a (batch of) of heterogeneous graph(s)
        Returns: An updated representation of the edge attributes for all edge_types
        """
        for position, (edge_type, edge_store) in enumerate(zip(graph.edge_types, graph.edge_stores)):
            edge_attr = edge_store.get("edge_attr")
            edge_indices = edge_store.get("edge_index")
            source_indices, dest_indices = edge_indices

            source_node_type, _, dest_node_type = edge_type
            source_node_index = graph.node_types.index(source_node_type)
            dest_node_index = graph.node_types.index(dest_node_type)

            edge_source_nodes = graph.node_stores[source_node_index]["x"][source_indices]
            edge_dest_nodes = graph.node_stores[dest_node_index]["x"][dest_indices]

            # concatenate everything
            aggregated_features = torch.cat([edge_source_nodes, edge_dest_nodes, edge_attr], 1)
            aggregated_features = self._maybe_global(aggregated_features, graph, source_node_type, source_indices)

            edge_store["edge_attr"] = self._mlps["".join(edge_type)](aggregated_features)

    def _concat_global(self, features, graph, source_node_type, source_indices):
        indices = graph[source_node_type].batch
        global_features = graph.u[indices[source_indices]]
        return torch.cat([features, global_features], 1)


class HeterogeneousNodeModule(HeterogeneousMetaModule):
    def __init__(self, *,
                 mlps: nn.ModuleDict,
                 num_edge_types: Dict[str, int],
                 latent_dimension: int,
                 scatter_reducers: List[Callable],
                 use_global_features: bool,
                 flip_edges_for_nodes: bool):
        """
        Args:
            mlps: Dictionary of {node_type: List[MLP]} of MLPs to use for each respective node
                Implemented as nn.ModuleDict/nn.ModuleList to correctly register parameters
            num_edge_types: How many edge types feed into each kind of node type
            latent_dimension: Dimensionality of the latent space. Also corresponds to the dimension of each node/edge
              message
            scatter_reducers: How to aggregate over the nodes/edges/globals. Can for example be [torch.scatter_mean]
        """

        super().__init__(mlps=mlps,
                         scatter_reducers=scatter_reducers,
                         use_global_features=use_global_features,
                         latent_dimension=latent_dimension)

        self.in_features = {node_type: mlp_list.in_features
                            for node_type, mlp_list in mlps.items()}  # in_features
        self.num_edge_types = num_edge_types

        # use the source indices for feat. aggregation if edges shall be flipped
        if flip_edges_for_nodes:
            self._get_edge_indices = lambda src_and_dest_indices: src_and_dest_indices[0]
        else:
            self._get_edge_indices = lambda src_and_dest_indices: src_and_dest_indices[1]

    def forward(self, graph: Batch):
        """
        Compute updates for each node feature vector as x_i' = f2(x_i, agg_j f1(e_ij, x_j), u),
        where f1 and f2 are MLPs
            graph: HeteroData object of pytorch geometric. Represents a (batch of) of heterogeneous graph(s)
        Returns: An updated representation of the edge attributes for all edge_types
        """
        for position, (node_type, node_store) in enumerate(zip(graph.node_types, graph.node_stores)):
            node_features = node_store.get("x")
            num_nodes = node_features.shape[0]
            n_edge_features = self.num_edge_types[node_type] * self.latent_dimension * self._n_scatter_reducers

            # define empty tensor that will store edge features
            all_edge_features = torch.zeros(size=(num_nodes, n_edge_features),
                                            device=node_features.device)

            relevant_edge_ids = [position
                                 for position, (_, _, dest_node_type) in enumerate(graph.edge_types)
                                 if dest_node_type == node_type]
            # look for all edges that have the current node type as destination

            edge_increment = self.latent_dimension * self._n_scatter_reducers
            for pos, edge_index in enumerate(relevant_edge_ids):
                edge_features = graph.edge_stores[edge_index].get("edge_attr")
                src_indices, dest_indices = graph.edge_stores[edge_index].get("edge_index")
                scatter_edge_indices = self._get_edge_indices((src_indices, dest_indices))
                aggr_features = self.multiscatter(features=edge_features, indices=scatter_edge_indices,
                                                  dim=0, dim_size=num_nodes)
                # aggr_features now has shape (num_nodes, latent_dimension * n_scatter_reducers)
                # this inner loop is across edge types
                all_edge_features[:, pos * edge_increment:(pos + 1) * edge_increment] = aggr_features

            # we need to repeat that with each type of aggregation. aggr_features

            aggregated_features = torch.cat([node_features, all_edge_features], 1)
            aggregated_features = self._maybe_global(aggregated_features, graph, node_type)

            # write update to graph
            node_store["x"] = self._mlps[node_type](aggregated_features)

    def _concat_global(self, features, graph, node_type):
        batch = graph[node_type].batch
        return torch.cat([features, graph.u[batch]], 1)


class HeterogeneousGlobalModule(HeterogeneousMetaModule):
    def forward(self, graph: Batch):
        """
        computes the forward pass for the global module
        Args:
            graph: of type torch_geometric.data.Batch

        Returns: None, in-place operation

        """
        edge_feature_list = []  # stores the reduced edge features per edge type
        node_feature_list = []  # stores the reduced node features per node type
        for edge_type, edge_store in zip(graph.edge_types, graph.edge_stores):
            edge_attr = edge_store.get("edge_attr")
            edge_indices = edge_store.get("edge_index")
            source_indices, _ = edge_indices
            source_node_type, _, _ = edge_type
            indices = graph[source_node_type].batch
            # indices assigns each node to a graph in the batch of graphs. We use this to aggregate over the edges
            # by querying this for the source node of each edge
            reduced_edge_features = self.multiscatter(features=edge_attr,
                                                      indices=indices[source_indices],  # query source node of each edge
                                                      dim=0,
                                                      dim_size=graph.u.shape[0])
            edge_feature_list.append(reduced_edge_features)

        for node_type, node_store in zip(graph.node_types, graph.node_stores):
            node_attr = node_store.get("x")
            reduced_node_features = self.multiscatter(features=node_attr,
                                                      indices=graph[node_type].batch,
                                                      dim=0,
                                                      dim_size=graph.u.shape[0])
            node_feature_list.append(reduced_node_features)

        aggregated_edge_features = torch.cat(edge_feature_list, 1)
        aggregated_node_features = torch.cat(node_feature_list, 1)

        aggregated_features = torch.cat([aggregated_node_features, aggregated_edge_features, graph.u], 1)
        graph.u = self._mlps["u"](aggregated_features)  # write back to graph
