#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : edgepool.py
# Author : Anonymous1
# Email  : anonymous1@anon
#
# Distributed under terms of the MIT license.

# Modified from https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/pool/edge_pool.py

from collections import namedtuple

import time
import numpy as np
import torch
import torch.nn.functional as F
from torch_scatter import scatter_add, scatter_max
from torch_sparse import coalesce

from torch_geometric.utils import softmax


class EdgePooling(torch.nn.Module):
    r"""The edge pooling operator from the `"Towards Graph Pooling by Edge
    Contraction" <https://graphreason.github.io/papers/17.pdf>`_ and
    `"Edge Contraction Pooling for Graph Neural Networks"
    <https://arxiv.org/abs/1905.10990>`_ papers.

    In short, a score is computed for each edge.
    Edges are contracted iteratively according to that score unless one of
    their nodes has already been part of a contracted edge.

    To duplicate the configuration from the "Towards Graph Pooling by Edge
    Contraction" paper, use either
    :func:`EdgePooling.compute_edge_score_softmax`
    or :func:`EdgePooling.compute_edge_score_tanh`, and set
    :obj:`add_to_edge_score` to :obj:`0`.

    To duplicate the configuration from the "Edge Contraction Pooling for
    Graph Neural Networks" paper, set :obj:`dropout` to :obj:`0.2`.

    Args:
        in_channels (int): Size of each input sample.
        edge_score_method (function, optional): The function to apply
            to compute the edge score from raw edge scores. By default,
            this is the softmax over all incoming edges for each node.
            This function takes in a :obj:`raw_edge_score` tensor of shape
            :obj:`[num_edges]`, an :obj:`edge_index` tensor and the number of
            nodes :obj:`num_nodes`, and produces a new tensor of the same size
            as :obj:`raw_edge_score` describing normalized edge scores.
            Included functions are
            :func:`EdgePooling.compute_edge_score_softmax`,
            :func:`EdgePooling.compute_edge_score_tanh`, and
            :func:`EdgePooling.compute_edge_score_sigmoid`.
            (default: :func:`EdgePooling.compute_edge_score_softmax`)
        dropout (float, optional): The probability with
            which to drop edge scores during training. (default: :obj:`0`)
        add_to_edge_score (float, optional): This is added to each
            computed edge score. Adding this greatly helps with unpool
            stability. (default: :obj:`0.5`)
    """

    # Everything :func:`unpool` needs to undo one pooling step.
    unpool_description = namedtuple(
        "UnpoolDescription", ["edge_index", "cluster", "batch", "new_edge_score"]
    )

    def __init__(
        self, in_channels, edge_score_method=None, dropout=0, add_to_edge_score=0.5
    ):
        """Store the configuration and build the linear edge scorer."""
        super().__init__()
        self.in_channels = in_channels
        if edge_score_method is None:
            edge_score_method = self.compute_edge_score_softmax
        self.compute_edge_score = edge_score_method
        self.add_to_edge_score = add_to_edge_score
        self.dropout = dropout

        # An edge is scored from the concatenated features of its two
        # endpoints, hence the 2 * in_channels input width.
        self.lin = torch.nn.Linear(2 * in_channels, 1)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the parameters of the linear edge scorer."""
        self.lin.reset_parameters()

    @staticmethod
    def compute_edge_score_softmax(raw_edge_score, edge_index, num_nodes):
        """Normalize raw scores with a softmax grouped by target node."""
        return softmax(raw_edge_score, edge_index[1], num_nodes=num_nodes)

    @staticmethod
    def compute_edge_score_tanh(raw_edge_score, edge_index, num_nodes):
        """Squash raw scores with tanh; the graph arguments are unused."""
        return torch.tanh(raw_edge_score)

    @staticmethod
    def compute_edge_score_sigmoid(raw_edge_score, edge_index, num_nodes):
        """Squash raw scores with sigmoid; the graph arguments are unused."""
        return torch.sigmoid(raw_edge_score)

    def forward(self, x, edge_index, batch):
        r"""Forward computation which computes the raw edge score, normalizes
        it, and merges the edges.

        Args:
            x (Tensor): The node features.
            edge_index (LongTensor): The edge indices.
            batch (LongTensor): Batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each node to a specific example.

        Return types:
            * **x** *(Tensor)* - The pooled node features.
            * **edge_index** *(LongTensor)* - The coarsened edge indices.
            * **batch** *(LongTensor)* - The coarsened batch vector.
            * **unpool_info** *(unpool_description)* - Information that is
              consumed by :func:`EdgePooling.unpool` for unpooling.
        """
        # One raw score per edge, computed from both endpoint features.
        e = torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=-1)
        e = self.lin(e).view(-1)
        # Dropout acts on the raw scores, before normalization.
        e = F.dropout(e, p=self.dropout, training=self.training)
        e = self.compute_edge_score(e, edge_index, x.size(0))
        e = e + self.add_to_edge_score

        return self.__merge_edges__(x, edge_index, batch, e)

    def __merge_edges__(self, x, edge_index, batch, edge_score):
        r"""Greedily contract edges in order of decreasing score.

        Edges are visited from highest to lowest score; an edge is contracted
        only if neither endpoint already belongs to a contracted edge. Each
        contracted edge becomes one cluster, and every node left over becomes
        a cluster of its own.

        Args:
            x (Tensor): The node features.
            edge_index (LongTensor): The edge indices.
            batch (LongTensor): The batch vector of the input nodes.
            edge_score (Tensor): One (normalized) score per edge.

        Return types:
            * **new_x** *(Tensor)* - Per-cluster sum of node features,
              multiplied by the cluster's score.
            * **new_edge_index** *(LongTensor)* - The edges mapped to cluster
              ids and coalesced.
            * **new_batch** *(LongTensor)* - The batch vector of the clusters.
            * **unpool_info** *(unpool_description)* - Information that is
              consumed by :func:`EdgePooling.unpool` for unpooling.
        """
        num_nodes = x.size(0)
        unmatched = np.ones(num_nodes, dtype=bool)
        # Explicit int64: the cluster vector is used as a scatter/gather
        # index, which torch expects to be int64, while NumPy's default
        # "int" is only 32 bits wide on some platforms.
        cluster_np = np.zeros(num_nodes, dtype=np.int64)

        order = torch.argsort(edge_score, descending=True)
        # Fetch as plain lists (on the CPU) to speed up the Python loop.
        sources = edge_index[0, order].tolist()
        targets = edge_index[1, order].tolist()
        edge_ids = order.tolist()

        # Select an edge if it is not incident to an already chosen edge.
        num_clusters = 0
        merged_edge_ids = []
        for edge_id, source, target in zip(edge_ids, sources, targets):
            if not (unmatched[source] and unmatched[target]):
                continue
            merged_edge_ids.append(edge_id)
            # For a self-loop both assignments hit the same node.
            cluster_np[source] = num_clusters
            cluster_np[target] = num_clusters
            unmatched[source] = False
            unmatched[target] = False
            num_clusters += 1

        # The remaining nodes are simply kept, numbered in node order after
        # all contracted clusters.
        leftover = np.flatnonzero(unmatched)
        cluster_np[leftover] = np.arange(
            num_clusters, num_clusters + leftover.size, dtype=np.int64
        )
        num_clusters += leftover.size
        cluster = torch.from_numpy(cluster_np).to(x.device)

        # We compute the new features as an addition of the old ones.
        new_x = scatter_add(x, cluster, dim=0, dim_size=num_clusters)

        # Index with an explicit long tensor so that an empty selection is
        # well defined and lives on the same device as the scores.
        merged_index = torch.tensor(
            merged_edge_ids, dtype=torch.long, device=edge_score.device
        )
        new_edge_score = edge_score[merged_index]
        if leftover.size > 0:
            # Kept nodes get a neutral score of 1, so pooling and unpooling
            # leave their features unscaled.
            remaining_score = x.new_ones((leftover.size,))
            new_edge_score = torch.cat([new_edge_score, remaining_score])
        new_x = new_x * new_edge_score.view(-1, 1)

        # Relabel edge endpoints with cluster ids and coalesce the result.
        new_edge_index, _ = coalesce(
            cluster[edge_index], None, num_clusters, num_clusters
        )

        # Every cluster takes the batch id of one of its member nodes
        # (members are presumably always from the same example).
        new_batch = x.new_empty(num_clusters, dtype=torch.long)
        new_batch = new_batch.scatter_(0, cluster, batch)

        unpool_info = self.unpool_description(
            edge_index=edge_index,
            cluster=cluster,
            batch=batch,
            new_edge_score=new_edge_score,
        )
        return new_x, new_edge_index, new_batch, unpool_info

    def unpool(self, x, unpool_info):
        r"""Unpools a previous edge pooling step.

        For unpooling, :obj:`x` should be of same shape as those produced by
        this layer's :func:`forward` function. Then, it will produce an
        unpooled :obj:`x` in addition to :obj:`edge_index` and :obj:`batch`.

        Args:
            x (Tensor): The node features.
            unpool_info (unpool_description): Information that has
                been produced by :func:`EdgePooling.forward`.

        Return types:
            * **x** *(Tensor)* - The unpooled node features.
            * **edge_index** *(LongTensor)* - The new edge indices.
            * **batch** *(LongTensor)* - The new batch vector.
        """
        # Undo the score scaling applied while pooling, then hand every
        # original node a copy of its cluster's features.
        unscaled = x / unpool_info.new_edge_score.view(-1, 1)
        new_x = unscaled[unpool_info.cluster]
        return new_x, unpool_info.edge_index, unpool_info.batch

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.in_channels})"


# Example: using this pooling layer on a dgl.DGLGraph (`self.pooling` is an
# EdgePooling instance and `h` holds the node features of `graph`).
# nodes, edges = graph.nodes(), graph.edges()
# edges = torch.stack(edges, dim=0)
# batch = torch.zeros_like(nodes)
# new_h, new_edges, new_batch, unpool_info = self.pooling(h, edges, batch)
# new_graph = dgl.graph((new_edges[0], new_edges[1]))
# new_h = layer(new_graph, new_h)
# new_h, _, _ = self.pooling.unpool(new_h, unpool_info)
