from typing import Union, Any, Callable, Iterable

import torch
from torch import nn
from torch_geometric.data.batch import Batch
from torch_geometric.data.hetero_data import HeteroData

from hmpn.abstract.abstract_step import AbstractStep
from hmpn.common.hmpn_util import noop
from hmpn.heterogeneous2.heterogeneous2_modules import (
    Heterogeneous2EdgeModule,
    Heterogeneous2NodeModule,
    Heterogeneous2GlobalModule,
)


class Heterogeneous2Step(AbstractStep):
    """
    Defines a single Message Passing Step over a heterogeneous observation graph that only involves a chosen set of
    edge types (``active_edges``) and the node types they connect.

    It first updates the edge-features. The node-features are updated next using the new edge-features. Finally,
    if global features are enabled, it updates the global features using the new edge- & node-features. Each update
    is done by a dedicated module (Edge, Node, Global) built from ``stack_config["mlp"]``. Features of node and edge
    types that are not active are left untouched.
    """

    def __init__(
        self,
        latent_dimension: int,
        scatter_reducers: Union[Callable, list[Callable]],
        stack_config: dict[str, Any],
        use_global_features: bool,
        flip_edges_for_nodes: bool,
        active_edges: Iterable[tuple[str, str, str]],
    ):
        """
        Initializes the Heterogeneous2Step.

        Args:
            latent_dimension: Dimension of the latent node, edge and global features.
            scatter_reducers: A scatter function or a list of scatter functions (e.g. from torch_scatter) used for
                aggregation. A single callable is treated as a one-element list.
            stack_config: Dictionary of stack configuration. Its ``"mlp"`` entry is forwarded to every module.
            use_global_features: Whether to build and use a global-feature module.
            flip_edges_for_nodes: Forwarded unchanged to the node module.
            active_edges: Edge types ``(source_node_type, relation, destination_node_type)`` that this step
                processes. Any iterable of 3-element sequences is accepted, including a one-shot generator.
        """
        super().__init__(
            stack_config=stack_config,
            latent_dimension=latent_dimension,
            use_global_features=use_global_features,
        )

        # The annotation allows a bare callable; normalize so len() and the modules always see a list.
        if callable(scatter_reducers):
            scatter_reducers = [scatter_reducers]
        n_scatter_ops = len(scatter_reducers)
        mlp_config = stack_config.get("mlp")

        # Materialize once: the edge types are iterated several times below, and a generator would silently
        # be empty after the first pass. Tuples keep them hashable when they come in as lists.
        active_edges = [tuple(edge_type) for edge_type in active_edges]
        self.active_edges = active_edges

        # Node types touched by any active edge, in first-seen order. Iterating a set of strings follows a
        # hash-seed dependent order, which would make module/parameter registration order differ between runs.
        ordered_nodes = list(
            dict.fromkeys(
                node_type for src, _, dest in active_edges for node_type in (src, dest)
            )
        )
        self.active_nodes = set(ordered_nodes)

        if use_global_features:
            global_dim = latent_dimension
            # presumably: one aggregate per active edge/node type and scatter op, plus the previous global
            # feature itself (the "+ 1") — TODO confirm against Heterogeneous2GlobalModule.
            in_dim = latent_dimension * (
                (len(self.active_edges) + len(self.active_nodes)) * n_scatter_ops + 1
            )
            self.global_module = Heterogeneous2GlobalModule(
                in_dim=in_dim,
                out_dim=latent_dimension,
                mlp_config=mlp_config,
                scatter_reducers=scatter_reducers,
                use_global_features=use_global_features,
            )
            self.maybe_global = self.global_module
        else:
            global_dim = 0
            self.global_module = None
            self.maybe_global = noop

        active_node_dims = {
            name: global_dim + latent_dimension for name in ordered_nodes
        }
        # presumably edge input = own, source and destination features (3x) — TODO confirm in the edge module.
        active_edge_dims = {
            name: global_dim + 3 * latent_dimension for name in active_edges
        }

        # Every destination node type gets one extra latent-sized input per scatter op for each
        # active edge type that points into it.
        for _, _, dest in active_edges:
            active_node_dims[dest] += latent_dimension * n_scatter_ops

        # edge module
        self.edge_module = Heterogeneous2EdgeModule(
            in_dims=active_edge_dims,
            out_dim=latent_dimension,
            mlp_config=mlp_config,
            scatter_reducers=scatter_reducers,
            use_global_features=use_global_features,
        )

        # node module
        self.node_module = Heterogeneous2NodeModule(
            in_dims=active_node_dims,
            out_dim=latent_dimension,
            mlp_config=mlp_config,
            scatter_reducers=scatter_reducers,
            use_global_features=use_global_features,
            flip_edges_for_nodes=flip_edges_for_nodes,
        )

        def get_ln():
            return (
                nn.LayerNorm(latent_dimension) if self.use_layer_norm else nn.Identity()
            )

        self._node_lns = nn.ModuleDict({name: get_ln() for name in ordered_nodes})
        # ModuleDict keys must be strings, so edge types are keyed by their str() form.
        self._edge_lns = nn.ModuleDict({str(name): get_ln() for name in active_edges})

        # Snapshots filled by _store_nodes/_store_edges and consumed by the residual helpers.
        self.old_node_stores = []
        self.old_edge_stores = []
        self.reset_parameters()

    def forward(self, graph: HeteroData):
        """
        Runs the step on the part of ``graph`` spanned by the active edge types and writes the results back.

        The node features ``x`` of the active node types and the ``edge_attr`` of the active edge types in
        ``graph`` are replaced or updated in place. The actual edge/node/global updates are delegated to
        ``AbstractStep.forward`` (not shown here).

        If ``("x", "level0", "x")`` is active, every node of the active node types takes part. Otherwise only
        nodes incident to at least one active edge are passed on, and all other nodes keep their features.
        """
        subgraph = graph.edge_type_subgraph(self.active_edges)
        subgraph = subgraph.node_type_subgraph(self.active_nodes)
        if ("x", "level0", "x") in self.active_edges:
            super().forward(subgraph)

            for node_type in self.active_nodes:
                graph[node_type].x = subgraph[node_type].x
        else:
            # Boolean mask per active node type: True for nodes incident to at least one active edge.
            node_masks: dict[str, torch.Tensor] = {}
            for node_type in self.active_nodes:
                x = graph[node_type].x
                node_masks[node_type] = torch.zeros(
                    x.shape[0], device=x.device, dtype=torch.bool
                )

            for (src_type, _, dest_type), edge_store in subgraph.edge_items():
                src_indices, dest_indices = edge_store.edge_index
                node_masks[src_type].index_fill_(0, src_indices, True)
                node_masks[dest_type].index_fill_(0, dest_indices, True)

            # Build the node-induced subgraph by hand; faster than subgraph.subgraph(node_masks).
            # NOTE(review): only ``x`` and ``edge_index`` are filtered/remapped here. Other per-node attributes
            # (e.g. a ``batch`` vector) keep their unfiltered length — confirm no module reads them here.
            reindex_dict = {}
            for node_type, mask in node_masks.items():
                subgraph[node_type].x = subgraph[node_type].x[mask]
                # old node index -> position among kept nodes (meaningful only where mask is True)
                reindex_dict[node_type] = torch.cumsum(mask, 0) - 1

            for (src_type, _, dest_type), edge_store in subgraph.edge_items():
                # Read-only use, so no clone is needed; the remapped index is a new tensor.
                edge_index = edge_store.edge_index
                new_edge_index = torch.empty_like(edge_index)
                new_edge_index[0] = reindex_dict[src_type].index_select(0, edge_index[0])
                new_edge_index[1] = reindex_dict[dest_type].index_select(0, edge_index[1])
                edge_store.edge_index = new_edge_index

            super().forward(subgraph)

            # Write the updated rows back into graph's tensors in place. x[mask] + (new - x[mask]) equals the
            # new rows (up to floating-point rounding); rows outside the mask are not touched.
            for node_type, mask in node_masks.items():
                graph[node_type].x[mask] += (
                    subgraph[node_type].x - graph[node_type].x[mask]
                )

        # Edges are never filtered, so the subgraph's edge features map one-to-one onto graph's.
        for src, name, dest in self.active_edges:
            graph[src, name, dest].edge_attr = subgraph[src, name, dest].edge_attr

    def _store_nodes(self, graph: Batch):
        """Snapshot (clone) ``x`` of every node store, in ``graph.node_stores`` order."""
        self.old_node_stores = [
            node_store.x.clone() for node_store in graph.node_stores
        ]

    def _store_edges(self, graph: Batch):
        """Snapshot (clone) ``edge_attr`` of every edge store, in ``graph.edge_stores`` order."""
        self.old_edge_stores = [
            edge_store.edge_attr.clone() for edge_store in graph.edge_stores
        ]

    def _add_node_residual(self, graph: Batch):
        """Add the snapshot from ``_store_nodes`` to ``x`` in place, pairing stores by position."""
        for node_store, old_node_store in zip(graph.node_stores, self.old_node_stores):
            node_store.x += old_node_store

    def _add_edge_residual(self, graph: Batch):
        """Add the snapshot from ``_store_edges`` to ``edge_attr`` in place, pairing stores by position."""
        for edge_store, old_edge_store in zip(graph.edge_stores, self.old_edge_stores):
            edge_store.edge_attr += old_edge_store

    def _node_layer_norm(self, graph: Batch) -> None:
        """Apply the per-type node LayerNorm (or Identity) to ``x`` of every active node type present."""
        for node_type, node_store in graph.node_items():
            if node_type in self._node_lns:
                node_store.x = self._node_lns[node_type](node_store.x)

    def _edge_layer_norm(self, graph: Batch) -> None:
        """Apply the per-type edge LayerNorm (or Identity) to ``edge_attr`` of every active edge type present."""
        for edge_type, edge_store in graph.edge_items():
            if str(edge_type) in self._edge_lns:
                edge_store.edge_attr = self._edge_lns[str(edge_type)](
                    edge_store.edge_attr
                )
