import abc
from typing import Dict, Any, List, Callable, Optional

import torch
from torch_geometric.data.batch import Batch

from hmpn.abstract.abstract_modules import AbstractMetaModule
from hmpn.common.latent_mlp import LatentMLP


class HomogeneousMetaModule(AbstractMetaModule, abc.ABC):
    """
    Base class for the homogeneous modules used in the GNN.
    They are used for updating node-, edge-, and global features.

    The class optionally owns a LatentMLP (exposed through the ``mlp`` property) and offers
    ``_maybe_concat_global`` as a hook for appending global features. Subclasses that call this hook with
    ``use_global_features=True`` have to override ``_concat_global``.
    """

    def __init__(self, *,
                 in_features: int,
                 latent_dimension: int,
                 mlp_config: Optional[Dict[str, Any]],
                 scatter_reducers: List[Callable],
                 create_mlp: bool = True,
                 use_global_features: bool = False):
        """
        Args:
            in_features: Number of input features
            latent_dimension: Dimensionality of the internal layers of the mlp
            mlp_config: Dictionary specifying the way that the MLPs for each update should look like
            scatter_reducers: How to aggregate over the nodes/edges/globals. Can for example be [torch.scatter_mean]
            create_mlp: Whether to create an MLP or not
            use_global_features: whether global features are used
        """
        super().__init__(scatter_reducers=scatter_reducers)
        if create_mlp:
            self._mlp = LatentMLP(in_features=in_features,
                                  latent_dimension=latent_dimension,
                                  config=mlp_config)
        else:
            self._mlp = None

        # Only the flag is stored. The dispatch itself lives in the method _maybe_concat_global, because a lambda
        # kept as an instance attribute cannot be pickled and would break serialising the whole module.
        self.use_global_features = use_global_features

    def _maybe_concat_global(self, features, graph):
        """
        Append global features to the given features if this module uses them.
        Args:
            features: so-far aggregated features
            graph: graph object that is handed on to _concat_global

        Returns: the result of _concat_global(features, graph) if use_global_features is set,
          otherwise the unchanged features
        """
        if self.use_global_features:
            return self._concat_global(features, graph)
        return features

    def _concat_global(self, features, graph):
        """
        Hook for concatenating global features. Raises NotImplementedError unless a subclass overrides it.
        """
        raise NotImplementedError(f"Module {type(self)} needs to implement _concat_global(self,features,graph)")

    @property
    def mlp(self):
        """
        The MLP of this module.
        Raises:
            AssertionError: if the module was created with create_mlp=False
        """
        # raised explicitly (instead of using an assert statement) so that the check also runs under `python -O`
        if self._mlp is None:
            raise AssertionError("MLP is not initialized")
        return self._mlp


class HomogeneousEdgeModule(HomogeneousMetaModule):
    """
    Edge update of one step of a homogeneous message passing GNN.
    The MLP input of every edge is the concatenation of the features of its source node, the features of its
    destination node and its own edge features. If global features are used, they are appended as well.
    """

    def __init__(self, *,
                 latent_dimension: int,
                 mlp_config: Dict[str, Any],
                 scatter_reducers: List[Callable],
                 use_global_features: bool = False):
        """
        Args:
            latent_dimension: Dimensionality of the internal layers of the mlp
            mlp_config: Dictionary specifying the way that the MLPs for each update should look like
            scatter_reducers: How to aggregate over the nodes/edges/globals. Handed on to the base class
            use_global_features: whether global features are appended to the MLP input
        """
        # one latent-sized part each for the edge and its two nodes, and one more for the global features
        n_input_parts = 4 if use_global_features else 3
        super().__init__(in_features=n_input_parts * latent_dimension,
                         latent_dimension=latent_dimension,
                         mlp_config=mlp_config,
                         scatter_reducers=scatter_reducers,
                         create_mlp=True,
                         use_global_features=use_global_features)

    def forward(self, graph: Batch):
        """
        Update the edge features of the given graph in-place: graph.edge_attr is overwritten with the MLP output.
        Args:
            graph: Data object of pytorch geometric. Represents a batch of homogeneous graphs
        Returns: None
        """
        senders, receivers = graph.edge_index
        mlp_input = torch.cat([graph.x[senders], graph.x[receivers], graph.edge_attr], dim=1)
        mlp_input = self._maybe_concat_global(mlp_input, graph)

        graph.edge_attr = self.mlp(mlp_input)

    def _concat_global(self, aggregated_features, graph):
        """
        Append the global features to the per-edge features.
        For every edge, the entry of graph.batch at the source node of the edge selects the row of graph.u to use.
        Args:
            aggregated_features: so-far aggregated features
            graph: pytorch_geometric.data.Batch object

        Returns: aggregated_features with the global features appended along dimension 1
        """
        senders, _ = graph.edge_index
        per_edge_globals = graph.u[graph.batch[senders]]
        return torch.cat([aggregated_features, per_edge_globals], dim=1)


class HomogeneousMessagePassingNodeModule(HomogeneousMetaModule):
    """
    Module for computing node updates of a step on a homogeneous message passing GNN. Node inputs are concatenated:
    Its own Node features, the reduced features of all incoming edges and optionally,
    global features are also concatenated to the input.
    """

    def __init__(self, *,
                 latent_dimension: int,
                 mlp_config: Dict[str, Any],
                 scatter_reducers: List[Callable],
                 use_global_features: bool = False,
                 flip_edges_for_nodes: bool = False):
        """
        Module responsible for the node update of a step on a homogeneous message passing GNN.
        Node inputs are concatenated: Its own Node features, the reduced features of all incoming edges and optionally,
        global features are also concatenated to the input.
        Args:
            latent_dimension: Dimensionality of the internal layers of the mlp
            mlp_config: Dictionary specifying the way that the MLPs for each update should look like
            scatter_reducers: How to aggregate over the nodes/edges/globals. Can for example be [torch.scatter_mean]
            use_global_features: whether global features are used
            flip_edges_for_nodes: whether to flip the edge indices for the aggregation of edge features
        """
        n_scatter_ops = len(scatter_reducers)
        in_features = latent_dimension * (1 + n_scatter_ops)  # node and aggregated incoming edge features
        if use_global_features:
            in_features += latent_dimension

        super().__init__(in_features=in_features,
                         latent_dimension=latent_dimension,
                         mlp_config=mlp_config,
                         scatter_reducers=scatter_reducers,
                         create_mlp=True,
                         use_global_features=use_global_features)

        # Kept as a plain flag that _get_edge_indices reads. A lambda stored on the instance could not be
        # pickled, which would prevent serialising the module as a whole.
        self._flip_edges_for_nodes = flip_edges_for_nodes

    def _get_edge_indices(self, src_and_dest_indices):
        """
        Select the edge indices over which the edge features are aggregated.
        Args:
            src_and_dest_indices: pair of (source indices, destination indices)

        Returns: the source indices if the edges shall be flipped, the destination indices otherwise
        """
        if self._flip_edges_for_nodes:
            return src_and_dest_indices[0]
        return src_and_dest_indices[1]

    def forward(self, graph: Batch):
        """
        Compute node updates for the nodes of the Module for homogeneous graphs in-place. The edge features are
        reduced per node with self.multiscatter, concatenated to the node features (and optionally to the global
        features) and fed through the MLP. The result is written back to graph.x.
        Args:
            graph: Batch object of pytorch_geometric.data, represents a batch of homogeneous graphs
        Returns: None. in-place operation.
        """
        src_indices, dest_indices = graph.edge_index
        scatter_edge_indices = self._get_edge_indices((src_indices, dest_indices))

        # dim_size is the number of nodes, so the aggregation yields exactly one row per node
        aggregated_edge_features = self.multiscatter(features=graph.edge_attr, indices=scatter_edge_indices,
                                                     dim=0, dim_size=graph.x.shape[0])
        aggregated_features = torch.cat([graph.x, aggregated_edge_features], dim=1)
        aggregated_features = self._maybe_concat_global(aggregated_features, graph)

        # update
        graph.x = self.mlp(aggregated_features)

    def _concat_global(self, aggregated_features, graph):
        """
        computation and concatenation of global features
        Args:
            aggregated_features: so-far aggregated features
            graph: pytorch_geometric.data.Batch object

        Returns: aggregated_features with the global features appended
        """
        # graph.batch selects, for every node, the row of graph.u that is appended to it
        global_features = graph.u[graph.batch]
        return torch.cat([aggregated_features, global_features], dim=1)


class HomogeneousGatNodeModule(torch.nn.Module):
    """
    Module for computing node updates of a step on a homogeneous GNN with a GATv2Conv layer.
    The layer receives the node features (with the global features appended if they are used), the possibly
    flipped edge indices and the edge features. Its output replaces the node features of the graph.
    """

    def __init__(self, *,
                 latent_dimension: int,
                 use_global_features: bool = False,
                 flip_edges_for_nodes: bool = False,
                 heads: int = 4,
                 ):
        """
        Module responsible for the node update of a step on a homogeneous GAT with edge updates.
        Args:
            latent_dimension: Dimensionality of the node and edge features
            use_global_features: whether global features are used
            flip_edges_for_nodes: whether to flip the edge indices before they are given to the GAT
            heads: number of attention heads for the GAT. Has to be positive and divide latent_dimension
        Raises:
            ValueError: if heads is not positive or does not divide latent_dimension
        """
        super().__init__()
        # Every head gets latent_dimension // heads channels. With a remainder, the heads together would no longer
        # add up to latent_dimension (assuming GATv2Conv's default concatenation of heads), so fail early instead.
        if heads <= 0 or latent_dimension % heads != 0:
            raise ValueError(f"latent_dimension ({latent_dimension}) must be divisible by "
                             f"a positive number of heads ({heads})")

        from torch_geometric.nn import GATv2Conv
        in_channels = latent_dimension
        if use_global_features:
            in_channels += latent_dimension
        self._gat = GATv2Conv(
            in_channels=in_channels,
            out_channels=latent_dimension // heads,
            heads=heads,
            add_self_loops=False,
            edge_dim=latent_dimension,
        )

        # Plain flags that are read by _maybe_concat_global and _maybe_flip_edges. Lambdas stored on the instance
        # could not be pickled, which would prevent serialising the module as a whole.
        self.use_global_features = use_global_features
        self._flip_edges_for_nodes = flip_edges_for_nodes

    def _maybe_concat_global(self, features, graph):
        """
        Append global features to the given features if this module uses them.
        Args:
            features: node features
            graph: pytorch_geometric.data.Batch object

        Returns: features with the global features appended if use_global_features is set,
          otherwise the unchanged features
        """
        if self.use_global_features:
            return self._concat_global(features, graph)
        return features

    def _maybe_flip_edges(self, edge_index):
        """
        Swap the two rows of the edge index if the edges shall be flipped.
        Args:
            edge_index: edge index tensor of the graph

        Returns: edge_index.flip(0) if the edges shall be flipped, otherwise the unchanged edge_index
        """
        if self._flip_edges_for_nodes:
            return edge_index.flip(0)
        return edge_index

    def forward(self, graph: Batch):
        """
        Compute node updates for the nodes of the Module for homogeneous graphs in-place. The output of the GAT
        layer is written back to graph.x. Uses a GAT approach with potential global features.
        Args:
            graph: Batch object of pytorch_geometric.data, represents a batch of homogeneous graphs
        Returns: None. in-place operation.
        """
        node_input = self._maybe_concat_global(graph.x, graph)
        graph_edges = self._maybe_flip_edges(graph.edge_index)
        graph.x = self._gat(node_input, graph_edges, edge_attr=graph.edge_attr)

    def _concat_global(self, aggregated_features, graph):
        """
        computation and concatenation of global features
        Args:
            aggregated_features: so-far aggregated features
            graph: pytorch_geometric.data.Batch object

        Returns: aggregated_features with the global features appended
        """
        # graph.batch selects, for every node, the row of graph.u that is appended to it
        global_features = graph.u[graph.batch]
        return torch.cat([aggregated_features, global_features], dim=1)


class HomogeneousGlobalModule(HomogeneousMetaModule):
    """
    Global feature update of one step of a homogeneous message passing GNN.
    The MLP input is the concatenation of the reduced features of all edges, the reduced features of all nodes
    and the current global features.
    """

    def __init__(self, *,
                 latent_dimension: int,
                 mlp_config: Dict[str, Any],
                 scatter_reducers: List[Callable],
                 use_global_features: bool = True):
        """
        Args:
            latent_dimension: Dimensionality of the internal layers of the mlp
            mlp_config: Dictionary specifying the way that the MLPs for each update should look like
            scatter_reducers: How to aggregate over the nodes/edges/globals. Handed on to the base class
            use_global_features: has to be True, a global module without global features is rejected
        """
        assert use_global_features, "Global features must be used for the global module"
        # one latent-sized part per reducer for the nodes and for the edges, plus one for the globals themselves
        n_input_parts = 1 + 2 * len(scatter_reducers)
        super().__init__(in_features=n_input_parts * latent_dimension,
                         latent_dimension=latent_dimension,
                         mlp_config=mlp_config,
                         scatter_reducers=scatter_reducers,
                         create_mlp=True,
                         use_global_features=use_global_features)

    def forward(self, graph: Batch):
        """
        Update the global features of the given graph in-place: graph.u is overwritten with the MLP output.
        Args:
            graph: Batch object of pytorch_geometric.data, represents a batch of homogeneous graphs
        Returns: None. in-place operation.
        """
        # both reductions produce as many rows as graph.u has
        n_global_rows = graph.u.shape[0]

        pooled_nodes = self.multiscatter(features=graph.x,
                                         indices=graph.batch,
                                         dim=0,
                                         dim_size=n_global_rows)

        # an edge is reduced under the graph.batch entry of its source node
        senders, _ = graph.edge_index
        pooled_edges = self.multiscatter(features=graph.edge_attr,
                                         indices=graph.batch[senders],
                                         dim=0,
                                         dim_size=n_global_rows)

        mlp_input = torch.cat([pooled_edges, pooled_nodes, graph.u], dim=1)
        graph.u = self.mlp(mlp_input)
