import numpy as np
import torch as tc
import torch.nn as nn


class LocalPooler(nn.Module):
    """
    Pools edge messages onto the nodes that receive them, using sum or average.

    Each message is routed to exactly one receiving node through a dense one-hot
    routing matrix of shape (n_edges_total, n_nodes), and aggregated there.
    """

    def __init__(self, direct=False, reduce="sum"):
        """
        Args:
            direct      If True, edge e carries one message, delivered to dst[e].
                        If False, messages are doubled: the first n_edges messages
                        are src->dst (delivered to dst) and the next n_edges are
                        dst->src (delivered to src).
            reduce      Aggregation type, either "sum" or "avg".
        """
        super(LocalPooler, self).__init__()
        self.direct = direct
        assert (
            reduce in ["sum", "avg"]
        ), f"Local pooling aggregation type (reduce) {reduce} is not implemented yet."
        self.reduce = reduce

    def forward(self, n_obj: list[int], msg: tc.Tensor, edge_index: tc.Tensor) -> tc.Tensor:
        """
        Pools directional messages to each node.

        Conceptually computes
        node_msg[n] = reduce of msg[:, e, :] over all messages e routed to node n

        Args:
            n_obj       List of nodes (e.g. for a grid of nodes you can use [x, y])
                        for undefined geometries use list[n_nodes] where n_nodes is
                        the total number of nodes in the system.
            msg         Edge messages of shape (n_samples, n_edges_total, width),
                        where n_edges_total == n_edges if ``direct`` else
                        2 * n_edges (src->dst block followed by dst->src block).
            edge_index  Integer tensor whose first axis has length 2, holding the
                        (src, dst) node index vectors, each of shape (n_edges,).
                        Indices are flat node indices in [0, prod(n_obj)).
        Returns:
            node_msg of shape (n_samples, *n_obj, width)
        Raises:
            ValueError  If msg's edge axis does not match the number of messages
                        implied by edge_index and ``direct``.
        """
        src, dst = edge_index                  # (n_edges,) both
        n_samples, n_edges_total, width = msg.shape
        n_nodes = int(np.prod(n_obj))

        if self.direct:
            # Route each edge to dst
            routing_dst = dst
        else:
            # We double messages: first n_edges are src->dst, second n_edges are dst->src
            routing_dst = tc.cat([dst, src], dim=0)  # (2 * n_edges,)

        # Without this check, a length-1 side would silently broadcast in the
        # advanced indexing below and route messages to the wrong nodes.
        if routing_dst.shape[0] != n_edges_total:
            raise ValueError(
                f"msg has {n_edges_total} edge messages but edge_index implies "
                f"{routing_dst.shape[0]} (direct={self.direct})."
            )

        # Step 1: Build (n_edges_total, n_nodes) one-hot routing matrix. Both index
        # tensors are placed on msg's device so the indexing never mixes devices.
        M = tc.zeros(n_edges_total, n_nodes, device=msg.device, dtype=msg.dtype)
        M[tc.arange(n_edges_total, device=msg.device), routing_dst.to(msg.device)] = 1.0

        # Step 2: Pool locally
        if self.reduce == "sum":
            pooled_msg = self.reduce_sum(msg, M)
        elif self.reduce == "avg":
            pooled_msg = self.reduce_avg(n_obj, msg, M, src, dst)
        else:
            # Only reachable if the __init__ assert was stripped (python -O) or
            # self.reduce was modified after construction.
            raise NotImplementedError(
                f"Local pooling aggregation type (reduce) {self.reduce} is not implemented yet."
            )

        # Step 3: Reshape the flat node axis back to the node layout n_obj
        return pooled_msg.reshape(n_samples, *n_obj, width)

    def reduce_sum(self, msg: tc.Tensor, M: tc.Tensor) -> tc.Tensor:
        """
        Realizes a local message aggregation (incoming messages, from edges), using sum.

        Args:
            msg             Incoming messages, i.e., messages gathered on the edges ready to
                            be transmitted, of shape (n_samples, n_edges_total, n_features).
            M               'Routing' matrix of shape (n_edges_total, n_nodes), where
                            n_edges_total is the number of transmitted messages: n_edges
                            for directed edges, 2 * n_edges for undirected ones.
        Returns:
            pooled_msg      Aggregated messages on the nodes, of shape
                            (n_samples, n_nodes, n_features).
        """
        # Pool with einsum (n_samples, n_edges_total, width) x (n_edges_total, n_nodes)
        # -> (n_samples, n_nodes, width)
        pooled_msg = tc.einsum("pew,en->pnw", msg, M)
        return pooled_msg

    def reduce_avg(self, n_obj: list[int], msg: tc.Tensor, M: tc.Tensor, src, dst) -> tc.Tensor:
        """
        Realizes a local message aggregation (incoming messages, from edges), using average.

        Args:
            n_obj           List of nodes (e.g. for a grid of nodes you can use [x, y])
                            for undefined geometries use list[n_nodes] where n_nodes is
                            the total number of nodes in the system.
            msg             Incoming messages, i.e., messages gathered on the edges ready to
                            be transmitted, of shape (n_samples, n_edges_total, n_features).
            M               'Routing' matrix of shape (n_edges_total, n_nodes), where
                            n_edges_total is the number of transmitted messages: n_edges
                            for directed edges, 2 * n_edges for undirected ones.
            src             Source node indices.
            dst             Destination node indices.
        Returns:
            pooled_msg      Mean of the incoming messages per node, of shape
                            (n_samples, n_nodes, n_features); nodes with no incoming
                            messages stay zero.
        """
        n_nodes = int(np.prod(n_obj))
        pooled_msg = self.reduce_sum(msg, M)

        # count how many messages each node receives (same routing as in forward)
        if self.direct:
            neighbor_count = tc.bincount(dst, minlength=n_nodes)
        else:
            neighbor_count = tc.bincount(tc.cat([dst, src], dim=0), minlength=n_nodes)

        # Clamp so isolated nodes divide by 1 (their sum is already zero). Cast to the
        # message dtype so half-precision inputs are not promoted to float32 by the division.
        neighbor_count = neighbor_count.clamp(min=1).to(
            device=pooled_msg.device, dtype=pooled_msg.dtype
        )
        pooled_msg = pooled_msg / neighbor_count.view(1, -1, 1)

        return pooled_msg
