from __future__ import annotations

from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
import torch.nn.functional as F

from torch_geometric.utils import (
    batched_negative_sampling,
    negative_sampling,
    dense_to_sparse,
    to_dense_batch,
    to_dense_adj,
    to_torch_coo_tensor,
)
from torch_geometric.nn.dense import dense_diff_pool, dense_mincut_pool
from torch_geometric.nn.dense.mincut_pool import _rank3_trace
from torch_geometric.nn.models.mlp import Linear
from torch_geometric.nn.resolver import activation_resolver

import numpy as np

import math
from math import floor

import sys

sys.path.insert(1, "../src/")

from neural_sbm.gcn import GCN
from neural_sbm.gcn2 import GCN2
from neural_sbm.mlp import MLP
from neural_sbm.dgcnn import DGCNN
# from neural_sbm.tvgnn import TVGNN

from neural_sbm.utils import get_num_params, positive_quadratic_root


def dense_dmon_pool(
    x: Tensor,
    adj: Tensor,
    s: Tensor,
    mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    r"""Dense DMoN-style pooling for externally supplied cluster assignments.

    Args:
        x (torch.Tensor): Node feature tensor
            :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
            batch-size :math:`B`, (maximum) number of nodes :math:`N` for
            each graph, and feature dimension :math:`F`. A 2-D input is
            treated as a single graph.
        adj (torch.Tensor): Adjacency tensor
            :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
            A 2-D input is treated as a single graph.
        s (torch.Tensor): Assignment tensor
            :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}`
            with number of clusters :math:`C`.
            The softmax does not have to be applied before-hand, since it is
            executed within this method.
        mask (torch.Tensor, optional): Mask matrix
            :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
            the valid nodes for each graph. (default: :obj:`None`)

    Returns:
        A 4-tuple ``(out, out_adj, spectral_loss, cluster_loss)``:
        the SELU-activated pooled features (:math:`B \times C \times F`), the
        coarsened adjacency with zeroed diagonal and symmetric degree
        normalisation (:math:`B \times C \times C`), and the two scalar
        losses averaged over the batch.

    :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`,
        :class:`torch.Tensor`, :class:`torch.Tensor`)
    """
    x = x.unsqueeze(0) if x.dim() == 2 else x
    adj = adj.unsqueeze(0) if adj.dim() == 2 else adj

    s = torch.softmax(s, dim=-1)

    (batch_size, num_nodes, _), C = x.size(), s.size(-1)

    if mask is None:
        mask = torch.ones(batch_size, num_nodes, dtype=torch.bool, device=x.device)

    mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
    # Padded nodes contribute neither features nor assignments.
    x, s = x * mask, s * mask

    out = F.selu(torch.matmul(s.transpose(1, 2), x))
    out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)

    # Spectral loss:
    degrees = torch.einsum("ijk->ij", adj)  # B x N
    degrees = degrees.unsqueeze(-1) * mask  # B x N x 1
    degrees_t = degrees.transpose(1, 2)  # B x 1 x N

    m = torch.einsum("ijk->i", degrees) / 2  # B
    m_expand = m.view(-1, 1, 1).expand(-1, C, C)  # B x C x C

    ca = torch.matmul(s.transpose(1, 2), degrees)  # B x C x 1
    cb = torch.matmul(degrees_t, s)  # B x 1 x C

    normalizer = torch.matmul(ca, cb) / 2 / m_expand
    decompose = out_adj - normalizer
    # Batched trace; written inline so the block does not rely on a private
    # torch_geometric helper.
    spectral_loss = -torch.einsum("ijj->i", decompose) / 2 / m
    spectral_loss = spectral_loss.mean()

    # Cluster loss:
    cluster_size = torch.einsum("ijk->ik", s)  # B x C
    cluster_loss = torch.norm(input=cluster_size, dim=1)  # B
    # Reduce the mask to shape (B,). Dividing by the (B, 1) result of
    # ``mask.sum(dim=1)`` would broadcast to (B, B) and mix the node counts
    # of different graphs in the batch.
    num_valid = mask.sum(dim=(1, 2))  # B
    # sqrt(C) is the Frobenius norm of the C x C identity.
    cluster_loss = cluster_loss / num_valid * (C**0.5) - 1
    cluster_loss = cluster_loss.mean()

    EPS = 1e-15

    # Fix and normalize coarsened adjacency matrix:
    ind = torch.arange(C, device=out_adj.device)
    out_adj[:, ind, ind] = 0
    d = torch.einsum("ijk->ij", out_adj)
    d = torch.sqrt(d)[:, None] + EPS
    out_adj = (out_adj / d) / d.transpose(1, 2)

    return out, out_adj, spectral_loss, cluster_loss


class AsymCheegerCutPool(torch.nn.Module):
    r"""
    The asymmetric cheeger cut pooling layer from the `"Total Variation Graph Neural Networks"
    <https://arxiv.org/abs/2211.06218>`_ paper.

    Args:
        k (int):
            Number of clusters or output nodes
        return_selection (bool):
            Whether to return selection matrix. Cannot be False
            if `return_pooled_graph` is False. (default: :obj:`False`)
        return_pooled_graph (bool):
            Whether to return pooled node features and adjacency.
            Cannot be False if `return_selection` is False. (default: :obj:`True`)
        totvar_coeff (float):
            Coefficient for graph total variation loss component. (default: :obj:`1.0`)
        balance_coeff (float):
            Coefficient for asymmetric norm loss component. (default: :obj:`1.0`)

    Raises:
        ValueError: If both `return_selection` and `return_pooled_graph` are False.
    """

    def __init__(
        self,
        k: int,
        return_selection: bool = False,
        return_pooled_graph: bool = True,
        totvar_coeff: float = 1.0,
        balance_coeff: float = 1.0,
    ):
        super().__init__()

        if not return_selection and not return_pooled_graph:
            raise ValueError(
                "return_selection and return_pooled_graph can not both be False"
            )

        self.k = k
        self.return_selection = return_selection
        self.return_pooled_graph = return_pooled_graph
        self.totvar_coeff = totvar_coeff
        self.balance_coeff = balance_coeff

    def forward(
        self,
        x: Tensor,
        adj: Tensor,
        s: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, ...]:
        r"""
        Args:
            x (Tensor):
                Node feature tensor :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`
                with batch-size :math:`B`, (maximum) number of nodes :math:`N` for each graph,
                and feature dimension :math:`F`. A 2-D input is treated as a single graph.
            adj (Tensor):
                Adjacency tensor :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`.
                A 2-D input is treated as a single graph.
            s (torch.Tensor): Assignment tensor
                :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}`
                with number of clusters :math:`C`.
                The softmax does not have to be applied before-hand, since it is
                executed within this method. A 2-D input is treated as a single graph.
            mask (BoolTensor, optional):
                Mask matrix :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}`
                indicating the valid nodes for each graph. (default: :obj:`None`)

        Returns:
            Depending on the constructor flags:

            * ``(s, x_pool, adj_pool, tv_loss, bal_loss)`` if both
              `return_selection` and `return_pooled_graph` are set,
            * ``(s, tv_loss, bal_loss)`` if only `return_selection` is set,
            * ``(x_pool, adj_pool, tv_loss, bal_loss)`` otherwise.
        """
        x = x.unsqueeze(0) if x.dim() == 2 else x
        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj
        # The pooling and both losses index ``s`` as a batched tensor, so a
        # single-graph assignment matrix needs a batch dimension as well.
        s = s.unsqueeze(0) if s.dim() == 2 else s

        s = torch.softmax(s, dim=-1)

        batch_size, n_nodes, _ = x.size()

        if mask is not None:
            mask = mask.view(batch_size, n_nodes, 1).to(x.dtype)
            x, s = x * mask, s * mask

        # Total variation loss
        tv_loss = self.totvar_coeff * torch.mean(self.totvar_loss(adj, s))

        # Balance loss
        bal_loss = self.balance_coeff * torch.mean(self.balance_loss(s))

        if not self.return_pooled_graph:
            return s, tv_loss, bal_loss

        # Pooled features and adjacency
        x_pool = torch.matmul(s.transpose(1, 2), x)
        adj_pool = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)

        if self.return_selection:
            return s, x_pool, adj_pool, tv_loss, bal_loss
        return x_pool, adj_pool, tv_loss, bal_loss

    def totvar_loss(self, adj, s):
        """Return the per-graph total variation of `s` over the edges of `adj`.

        Expects batched inputs (`adj` as B x N x N, `s` as B x N x C) and
        returns a tensor of shape [B].
        """
        # Pairwise l1 distance between the assignment rows of every node pair.
        l1_norm = torch.sum(torch.abs(s[..., None, :] - s[:, None, ...]), dim=-1)

        loss = torch.sum(adj * l1_norm, dim=(-1, -2))

        # Normalize loss. A graph without edges has zero loss; clamping the
        # edge count keeps the normalisation from producing 0 * inf = NaN.
        n_edges = torch.clamp(torch.count_nonzero(adj, dim=(-1, -2)), min=1)
        loss = loss * (1 / (2 * n_edges))

        return loss

    def balance_loss(self, s):
        """Return the per-graph asymmetric-norm balance loss of `s` (shape [B]).

        Expects `s` as B x N x C. The normalisation divides by ``k - 1``, so
        `k` must be at least 2.
        """
        n_nodes = s.size()[-2]

        # k-quantile
        idx = int(math.floor(n_nodes / self.k))
        quant = torch.sort(s, dim=-2, descending=True)[0][:, idx, :]  # shape [B, K]

        # Asymmetric l1-norm
        loss = s - torch.unsqueeze(quant, dim=1)
        loss = (loss >= 0) * (self.k - 1) * loss + (loss < 0) * loss * -1
        loss = torch.sum(loss, dim=(-1, -2))  # shape [B]
        loss = 1 / (n_nodes * (self.k - 1)) * (n_nodes * (self.k - 1) - loss)

        return loss


def smart_teleport(A, alpha=0.15, iter=1000):
    """Compute the flow matrix and node visit rates used by the map equation.

    Args:
        A: Square adjacency matrix; reductions over it are densified with
            ``to_dense()`` before use.
        alpha: Teleportation probability.
        iter: Number of power-iteration steps.

    Returns:
        A pair ``(F, p)`` with the flow matrix and the visit-rate vector
        obtained after ``iter`` power-iteration steps.
    """
    target_device = A.device
    total_weight = torch.sum(A)

    # Transition matrix: every row of A divided by its row sum. Rows that sum
    # to zero produce NaNs (0 * inf), which are replaced by zeros.
    inverse_row_sums = torch.sum(A, 1).to_dense() ** (-1)
    T = torch.nan_to_num(A.T * inverse_row_sums, nan=0.0).T.to(device=target_device)

    # Teleportation distribution: column sums of A as a share of total weight.
    e_v = (torch.sum(A, dim=0) / total_weight).to_dense().to(device=target_device)

    # Power iteration for the flow distribution, started at e_v.
    p = e_v
    remaining = iter
    while remaining > 0:
        p = alpha * e_v + (1 - alpha) * p @ T
        remaining -= 1

    # Flow matrix for minimising the map equation.
    teleport_flow = alpha * A / total_weight
    walk_flow = (1 - alpha) * (p * T.T).T
    flow = teleport_flow + walk_flow

    return flow, p


class NeuromapPooling(torch.nn.Module):
    r"""This criterion computes the map equation codelength for an undirected or directed weighted graph.

    The flow matrix and node visit rates are computed once from `adj` at
    construction time via :func:`smart_teleport`.

    Args:
        adj (torch.Tensor): (Unnormalized) Adjacency matrix of the weighted graph.
        epsilon (float, optional): Small epsilon added to the assignment matrix
            in :meth:`forward` to ensure differentiability of logs.
            (default: :obj:`1e-8`)

    """

    def __init__(
        self,
        adj: Tensor,
        epsilon: float = 1e-8,
    ):
        super().__init__()

        self.epsilon = epsilon
        self.F, self.p = smart_teleport(adj)

    @staticmethod
    def _plogp(values: Tensor) -> Tensor:
        """Return ``sum(v * log2(v))``, with NaN logarithms replaced by zero."""
        return torch.sum(values * torch.nan_to_num(torch.log2(values), nan=0.0))

    def forward(
        self,
        x: Tensor,
        s: Tensor,
        eps: Optional[float] = None,
    ):
        """Pool `x` with assignments `s` and evaluate the map equation.

        Args:
            x (Tensor): Node features, one row per node.
            s (Tensor): Assignment matrix, one row per node and one column
                per cluster.
            eps (float, optional): Value added to every entry of `s`. Falls
                back to the `epsilon` given to the constructor when omitted.

        Returns:
            ``(out, out_adj, map_equation_loss)``: pooled features
            ``s.T @ x``, pooled flow matrix ``s.T @ F @ s`` and the scalar
            codelength loss.
        """
        if eps is None:
            # Honour the constructor argument instead of a second hard-coded
            # default.
            eps = self.epsilon
        s = s + eps

        out_adj = s.T @ self.F @ s

        diagonal = torch.diag(out_adj)
        row_sums = torch.sum(out_adj, 1)
        col_sums = torch.sum(out_adj, 0)

        # Total off-diagonal flow, per-cluster off-diagonal row flow, node
        # visit rates, and per-cluster row flow plus off-diagonal column flow.
        e1 = torch.sum(out_adj) - torch.trace(out_adj)
        e2 = row_sums - diagonal
        e3 = self.p
        e4 = row_sums + (col_sums - diagonal)

        map_equation_loss = (
            self._plogp(e1) - 2 * self._plogp(e2) - self._plogp(e3) + self._plogp(e4)
        )

        out = torch.matmul(s.T, x)

        return out, out_adj, map_equation_loss


# https://github.com/SherylHYX/pytorch_geometric_signed_directed/blob/main/torch_geometric_signed_directed/utils/directed/prob_imbalance_loss.py
class Prob_Imbalance_Loss(torch.nn.Module):
    r"""An implementation of the probabilistic imbalance loss function from the
    `DIGRAC: Digraph Clustering Based on Flow Imbalance <https://proceedings.mlr.press/v198/he22b.html>`_ paper.

    Args:
        F (int or NumPy array, optional) - Number of pairwise imbalance scores to consider, or the meta-graph adjacency matrix.
            When an array is given, the number of cluster pairs with at least
            one edge between them is used. Only required for the ``'sort'``
            threshold of :meth:`forward`.
    """

    def __init__(self, F: Optional[Union[int, np.ndarray]] = None):
        super(Prob_Imbalance_Loss, self).__init__()
        # Stays None when no F is given, so that forward() can report the
        # missing value explicitly.
        self.sel = None
        if isinstance(F, (int, np.integer)):
            self.sel = F
        elif F is not None:
            K = F.shape[0]
            self.sel = 0
            for i in range(K - 1):
                for j in range(i + 1, K):
                    if (F[i, j] + F[j, i]) > 0:
                        self.sel += 1

    @staticmethod
    def _normalized_imbalance(
        w_kl, w_lk, vol_k, vol_l, normalization, epsilon, second_max_vol
    ):
        """Return the imbalance of one cluster pair under `normalization`."""
        gap = torch.abs(w_kl - w_lk)
        if normalization == "vol_sum":
            return gap / (vol_k + vol_l + epsilon) * 2
        if normalization == "vol_min":
            return gap / (w_kl + w_lk) * torch.min(vol_k, vol_l) / second_max_vol
        if normalization == "vol_max":
            return gap / (torch.max(vol_k, vol_l) + epsilon)
        # 'plain'
        return gap / (w_kl + w_lk)

    @staticmethod
    def _mean(terms):
        """Average a list of one-element tensors while keeping the autograd graph."""
        return torch.cat([term.reshape(1) for term in terms]).mean()

    def forward(
        self,
        P: Tensor,
        A: Tensor,
        K: int,
        normalization: str = "vol_sum",
        threshold: str = "sort",
    ) -> Tensor:
        """Making a forward pass of the probabilistic imbalance loss function from the
        `DIGRAC: Digraph Clustering Based on Flow Imbalance" <https://arxiv.org/pdf/2106.05194.pdf>`_ paper.
            Arg types:
                * **P** (PyTorch FloatTensor) - Prediction probability matrix made by the model
                * **A** (PyTorch FloatTensor, can be sparse) - Adjacency matrix A
                * **K** (int) - Number of clusters
                * **normalization** (str, optional) - normalization method:

                    'vol_sum': Normalized by the sum of volumes, the default choice.

                    'vol_max': Normalized by the maximum of volumes.

                    'vol_min': Normalized by the minimum of volumes.

                    'plain': No normalization, just CI.
                * **threshold**: (str, optional) thresholding method:

                    'sort': Picking the top beta imbalance values, the default choice.

                    'std': Picking only the terms 3 standard deviation away from null hypothesis.

                    'naive': No thresholding, summing up all K*(K-1)/2 terms of imbalance values.

            Return types:
                * **loss** (torch.Tensor) - loss value of shape (1,), roughly in [0,1].

            Raises:
                ValueError: if ``threshold='sort'`` is used without `F`
                having been passed to the constructor.
        """
        assert normalization in [
            "vol_sum",
            "vol_min",
            "vol_max",
            "plain",
        ], "Please input the correct normalization method name!"
        assert threshold in [
            "sort",
            "std",
            "naive",
        ], "Please input the correct threshold method name!"
        if threshold == "sort" and self.sel is None:
            raise ValueError(
                "threshold='sort' requires F to be passed to Prob_Imbalance_Loss"
            )

        device = A.device
        # avoid zero volume as denominator
        epsilon = torch.tensor([1e-8], device=device)

        # first calculate the probabilistic volume of each cluster
        symmetrized = A + torch.transpose(A, 0, 1)  # loop-invariant
        vol = torch.zeros(K).to(device)
        for k in range(K):
            vol[k] = torch.sum(torch.matmul(symmetrized, P[:, k : k + 1]))

        # Only 'vol_min' needs the second largest volume.
        second_max_vol = None
        if normalization == "vol_min":
            second_max_vol = torch.topk(vol, 2).values[1] + epsilon

        imbalance = []  # pairs that pass the threshold
        imbalance_std = []  # below-threshold pairs of the 'std' scheme
        for k in range(K - 1):
            for l in range(k + 1, K):
                w_kl = torch.matmul(P[:, k], torch.matmul(A, P[:, l]))
                w_lk = torch.matmul(P[:, l], torch.matmul(A, P[:, k]))
                gap = (w_kl - w_lk).item()
                if gap == 0:
                    continue
                curr = self._normalized_imbalance(
                    w_kl, w_lk, vol[k], vol[l], normalization, epsilon, second_max_vol
                )
                if threshold != "std" or gap**2 - 9 * (w_kl + w_lk).item() > 0:
                    imbalance.append(curr)
                else:
                    imbalance_std.append(curr)

        one = torch.ones(1, requires_grad=True).to(device)

        if threshold == "sort":
            imbalance_values = np.array([curr.item() for curr in imbalance])
            result = torch.zeros(1).to(device)
            # descending order
            for ind in np.argsort(-imbalance_values)[: int(self.sel)]:
                result = result + imbalance[ind]
            # take negation to be minimized
            return one - result / self.sel
        if imbalance:
            # Averaging the tensors directly keeps the loss differentiable
            # with respect to P; rebuilding them as a new tensor would not.
            return one - self._mean(imbalance)
        if threshold == "std" and imbalance_std:
            # nothing passed the threshold, then disregard thresholding
            return one - self._mean(imbalance_std)
        # nothing has positive imbalance
        return one


# TODO: signed networks
class NeuralSBM(torch.nn.Module):
    """Stochastic block model whose block matrix is predicted by a network.

    Args:
        in_channels (int): Accepted for interface compatibility; not used by
            this class.
        max_clusters (int): Accepted for interface compatibility; not used by
            this class.
        nn (torch.nn.Module): Network applied to the cluster assignments to
            produce block embeddings. In the directed case its output is
            split in two halves along the last dimension.
        max_clusters_per_node (int, optional): Keep only this many largest
            memberships per node. :obj:`None` keeps all of them.
        overlapping (bool): If False, every node keeps a single cluster.
        directed (bool): Whether the block matrix may be asymmetric.
        no_blocks (bool): If True, no block matrix is used at all.
        fixed_blocks (Tensor, optional): Use this block matrix instead of a
            predicted one.
        sparse (bool): Squash the block matrix with a hard sigmoid (True) or
            a regular sigmoid (False).

    Raises:
        ValueError: If `fixed_blocks` is given but is not a tensor (only
            checked when `no_blocks` is False).
    """

    def __init__(
        self,
        in_channels: int,
        max_clusters: int,
        nn: torch.nn.Module,
        max_clusters_per_node: Optional[int] = None,
        overlapping: bool = True,
        directed: bool = False,
        no_blocks: bool = False,
        fixed_blocks: Optional[Tensor] = None,
        sparse: bool = True,
    ):
        super().__init__()

        self.directed = directed
        self.fixed_blocks = fixed_blocks
        self.no_blocks = no_blocks
        self.sparse = sparse

        # Non-overlapping clustering is top-1 selection.
        self.max_clusters_per_node = max_clusters_per_node if overlapping else 1

        if (
            not self.no_blocks
            and self.fixed_blocks is not None
            and not isinstance(self.fixed_blocks, torch.Tensor)
        ):
            raise ValueError("Fixed blocks must be a tensor!")

        self.nn = nn

    def top_k(
        self,
        clusters: Tensor,
    ):
        """Zero all but the `max_clusters_per_node` largest entries per row."""
        if self.max_clusters_per_node is None:
            return clusters

        # Asking for more clusters than exist keeps all of them rather than
        # making torch.topk fail.
        k = min(self.max_clusters_per_node, clusters.size(-1))
        indices = torch.topk(clusters, k=k, dim=-1).indices
        mask = torch.zeros_like(clusters, dtype=torch.bool)
        mask.scatter_(dim=-1, index=indices, value=True)
        return clusters * mask

    def forward(
        self,
        s: Tensor,
    ):
        """Compute sparsified cluster memberships and the block matrix.

        Currently only supports sparse batching, so `s` is a tensor of shape
        (n, k) where n is the number of nodes and k is the number of clusters.

        Returns:
            ``(clusters, block_matrix)``; `block_matrix` is :obj:`None` when
            `no_blocks` is set. Both are also stored on the module for
            :meth:`get_adj`.
        """
        if self.no_blocks:
            block_matrix = None
        elif self.fixed_blocks is not None:
            block_matrix = self.fixed_blocks
        else:
            if self.directed:
                blocks_in, blocks_out = self.nn(s).chunk(2, dim=-1)
                block_matrix = torch.matmul(blocks_in.transpose(0, 1), blocks_out)
            else:
                blocks = self.nn(s)
                block_matrix = torch.matmul(blocks.transpose(0, 1), blocks)

            if self.sparse:
                block_matrix = torch.nn.functional.hardsigmoid(block_matrix)
            else:
                block_matrix = torch.sigmoid(block_matrix)

        # TODO: check if relu is needed here for weighted graphs. If softmax is
        # used, relu is not needed here. For unweighted graphs, relu is not
        # needed here. Also if weighted graphs are normalized, relu is not
        # needed here. Normalization is not required yet here as (1) for
        # unweighted graphs, cross entropy loss supports unnormalized logits
        # and (2) for weighted graphs, graphs are not necessarily normalized.
        self.clusters = self.top_k(s.relu())
        self.block_matrix = block_matrix

        return self.clusters, self.block_matrix

    def get_adj(self):
        """Return the dense predicted adjacency from the last :meth:`forward` call."""
        if self.no_blocks:
            sources = self.clusters
        else:
            sources = torch.matmul(self.clusters, self.block_matrix)  # (n, k) @ (k, k)
        # (n, k) @ (k, n) -> (n, n)
        return torch.matmul(sources, self.clusters.transpose(0, 1))


class NeuralSBMPool(torch.nn.Module):
    """Pooling layer that dispatches to one of several clustering objectives.

    `pooling_method` selects how the pooled graph and the main loss are
    produced (``"SBMPool"``, ``"NOCD"`` as an alias for SBMPool without
    blocks, ``"MapEqPool"``, ``"DIGRAC"``, ``"DiffPool"``, ``"MinCutPool"``,
    ``"DMoNPool"``, ``"AsymCheegerCutPool"`` or :obj:`None`).
    `regularization_method` optionally adds the auxiliary loss of one of the
    dense methods; it defaults to `pooling_method`.

    Args:
        in_channels (int): Forwarded to :class:`NeuralSBM`.
        max_clusters (int): Number of clusters.
        bnn (torch.nn.Module): Network mapping assignments to block
            embeddings, ``nn(s) -> b``.
        max_clusters_per_node, overlapping, directed, no_blocks, fixed_blocks,
        sparse: Forwarded to :class:`NeuralSBM`; `sparse` also selects
            L1-normalisation instead of softmax for the pooling matrix.
        weighted (bool): Use a squared-error instead of a cross-entropy
            reconstruction loss.
        graph_reconstruction_method (str): ``"full"``, ``"sample"`` or
            ``"pseudosparse"``.
        num_neg_samples (int, optional): Number of negative samples for
            ``"sample"``.
        observed_graph_is_dense (bool): Selects the ``"dense"`` negative
            sampling method.
        adj (Tensor, optional): Adjacency matrix, required for MapEqPool.
        pooling_method (str, optional): See above.
        regularization_method (str, optional): See above.
        nn (torch.nn.Module, optional): Network mapping features to
            assignments, ``nn(x) -> s``.

    Raises:
        ValueError: If `pooling_method` is ``"MapEqPool"`` and `adj` is None.
    """

    # Methods that operate on dense (batched) tensors.
    _DENSE_METHODS = ("DiffPool", "MinCutPool", "DMoNPool", "AsymCheegerCutPool")
    _SUPPORTED_POOLING = (None, "SBMPool", "MapEqPool", "DIGRAC") + _DENSE_METHODS

    def __init__(
        self,
        in_channels: int,
        max_clusters: int,
        # NeuralSBM args
        bnn: torch.nn.Module,  # nn(s) -> b
        max_clusters_per_node: Optional[int] = None,
        overlapping: bool = True,
        directed: bool = False,
        weighted: bool = False,
        no_blocks: bool = False,
        fixed_blocks: Optional[Tensor] = None,
        sparse: bool = False,
        # SBMPool link prediction loss args
        graph_reconstruction_method: str = "sample",
        num_neg_samples: Optional[int] = None,
        observed_graph_is_dense: bool = False,
        # MapEqPool args
        adj: Optional[Tensor] = None,
        # other args
        pooling_method: Optional[str] = "SBMPool",
        regularization_method: Optional[str] = None,
        nn: Optional[torch.nn.Module] = None,  # nn(x) -> s
    ):
        super().__init__()

        self.weighted = weighted
        self.sparse = sparse

        # NOCD is SBMPool without a block matrix.
        if pooling_method == "NOCD":
            pooling_method = "SBMPool"
            no_blocks = True

        self.no_blocks = no_blocks

        self.graph_reconstruction_method = graph_reconstruction_method
        self.num_neg_samples = num_neg_samples
        self.observed_graph_is_dense = observed_graph_is_dense

        if pooling_method == "SBMPool":
            self.sbm = NeuralSBM(
                in_channels=in_channels,
                max_clusters=max_clusters,
                max_clusters_per_node=max_clusters_per_node,
                nn=bnn,
                overlapping=overlapping,
                directed=directed,
                no_blocks=no_blocks,
                fixed_blocks=fixed_blocks,
                sparse=sparse,
            )

        if pooling_method == "MapEqPool":
            if adj is None:
                raise ValueError("adj must be provided for MapEqPool!")
            self.map_eq_pool = NeuromapPooling(adj=adj)

        if pooling_method == "DIGRAC":
            self.prob_imbalance_loss = Prob_Imbalance_Loss()

        if "AsymCheegerCutPool" in [pooling_method, regularization_method]:
            self.asym_cheeger_cut_pool = AsymCheegerCutPool(k=max_clusters)

        self.nn = nn
        self.pooling_method = pooling_method

        if regularization_method is None:
            self.regularization_method = pooling_method
        else:
            self.regularization_method = regularization_method

    def forward(
        self,
        x: Tensor,
        s: Optional[Tensor] = None,
        edge_index: Optional[Tensor] = None,
        edge_weight: Optional[Tensor] = None,
        # TODO: support more general edge_attr: Optional[Tensor] = None,
        batch: Optional[Tensor] = None,
    ):
        """Cluster the nodes and return the pooled graph with its loss.

        Args:
            x (Tensor): Node features of shape (n, d).
            s (Tensor, optional): Cluster assignments of shape (n, k).
                Required when no `nn` was given; otherwise computed as
                ``nn(x)``.
            edge_index (Tensor, optional): Edge index of shape (2, e).
            edge_weight (Tensor, optional): Edge weights of shape (e,).
            batch (Tensor, optional): Batch vector assigning nodes to graphs.

        Returns:
            ``(pool, pooled_x, pooled_adj, loss)``. `pooled_x` and
            `pooled_adj` are :obj:`None` when `pooling_method` is None.

        Raises:
            ValueError: On missing inputs, an unknown pooling or graph
                reconstruction method, or a dense method without `edge_index`.
        """
        if self.nn is None:
            if s is None:
                raise ValueError("s must be provided if nn is None!")
        else:
            if x is None:
                raise ValueError("x must be provided if nn is not None!")
            s = self.nn(x)

        if self.pooling_method not in self._SUPPORTED_POOLING:
            # Without this check an unknown method only fails at the return
            # statement with an UnboundLocalError.
            raise ValueError(f"Unknown pooling method: {self.pooling_method!r}")

        loss = 0.0
        pooled_x = None
        pooled_adj = None

        if self.pooling_method == "SBMPool":
            clusters, block_matrix = self.sbm(s)

            if self.sparse:
                pool = torch.nn.functional.normalize(clusters, p=1, dim=-1)
            else:
                pool = torch.nn.functional.softmax(clusters, dim=-1)

            # Pooled adjacency: (pool^T C) [B] (C^T pool)
            left = torch.matmul(pool.transpose(0, 1), clusters)
            right = torch.matmul(clusters.transpose(0, 1), pool)
            if self.no_blocks:
                pred_pooled_adj = torch.matmul(left, right)
            else:
                pred_pooled_adj = torch.matmul(torch.matmul(left, block_matrix), right)

            pooled_x = torch.matmul(pool.transpose(0, 1), x)

            if edge_index is not None:
                if edge_weight is None:
                    edge_weight = torch.ones(
                        edge_index.size(1), device=edge_index.device
                    )

                if self.graph_reconstruction_method not in [
                    "full",
                    "sample",
                    "pseudosparse",
                ]:
                    raise ValueError("Link prediction loss method not implemented!")

                num_nodes = x.size(0)

                if self.graph_reconstruction_method == "sample":
                    sampling_method = (
                        "dense" if self.observed_graph_is_dense else "sparse"
                    )
                    if batch is not None:
                        # batched_negative_sampling derives the node counts
                        # from `batch`; it takes no `num_nodes` argument.
                        samples = batched_negative_sampling(
                            edge_index=edge_index,
                            batch=batch,
                            num_neg_samples=self.num_neg_samples,
                            method=sampling_method,
                        )
                    else:
                        samples = negative_sampling(
                            edge_index=edge_index,
                            num_neg_samples=self.num_neg_samples,
                            method=sampling_method,
                            num_nodes=num_nodes,
                        )

                    # Negative samples are appended as zero-weight edges.
                    edge_index = torch.cat([edge_index, samples], dim=-1)
                    edge_weight = torch.cat(
                        [
                            edge_weight,
                            torch.zeros(samples.size(1), device=edge_weight.device),
                        ],
                        dim=-1,
                    )

                # TODO: Implement maximum sparsity budget by selecting top k from clusters_unpool and block_matrix
                # Could do this by setting a max_clusters_per_node parameter which constrains the number of clusters per node vs the total number of clusters (max_clusters) https://pytorch.org/docs/stable/generated/torch.topk.html
                if self.graph_reconstruction_method == "pseudosparse":
                    # NOTE(review): the sparse prediction and its indices are
                    # computed but not used by the loss below yet - confirm
                    # the intended behaviour before relying on this mode.
                    if self.no_blocks:
                        pred_adj = torch.sparse.mm(
                            clusters.to_sparse(), clusters.transpose(0, 1).to_sparse()
                        )
                    else:
                        pred_adj = torch.sparse.mm(
                            clusters.to_sparse(),
                            torch.sparse.mm(
                                block_matrix.to_sparse(),
                                clusters.transpose(0, 1).to_sparse(),
                            ),
                        )

                    if not pred_adj.is_coalesced():
                        pred_adj = pred_adj.coalesce()

                    samples = pred_adj.indices()

                if self.graph_reconstruction_method == "full":
                    # Size the target like the (n, n) prediction even if
                    # trailing nodes are isolated, and use the edge weights
                    # as target in the weighted case.
                    adj = to_dense_adj(
                        edge_index,
                        edge_attr=edge_weight if self.weighted else None,
                        max_num_nodes=num_nodes,
                    )[0]

                    pred_adj = self.sbm.get_adj()

                    if self.weighted:
                        link_loss = F.mse_loss(pred_adj, adj)
                    else:
                        link_loss = F.cross_entropy(pred_adj, adj)
                else:
                    if self.no_blocks:
                        sources = clusters
                    else:
                        sources = torch.matmul(clusters, block_matrix)

                    # sparse sampled dense dense matmul
                    sources = sources[edge_index[0]]
                    targets = clusters[edge_index[1]]
                    pred_edges = torch.sum(sources * targets, dim=-1)

                    if self.weighted:
                        link_loss = F.mse_loss(pred_edges, edge_weight)
                    else:
                        link_loss = F.binary_cross_entropy_with_logits(
                            pred_edges, edge_weight
                        )

                loss += link_loss

            pooled_adj = pred_pooled_adj

            # allow custom normalization for semi-supervised learning
            pool = clusters
        else:
            pool = s.softmax(dim=-1)

        if self.pooling_method == "MapEqPool":
            pooled_x, pooled_adj, codelength_loss = self.map_eq_pool(x=x, s=pool)
            loss += codelength_loss

        if self.pooling_method == "DIGRAC":
            A = to_torch_coo_tensor(edge_index, edge_weight, x.size(0))
            pooled_x = torch.matmul(pool.transpose(0, 1), x)
            pooled_adj = torch.matmul(torch.matmul(pool.transpose(0, 1), A), pool)
            prob_imbalance_loss = self.prob_imbalance_loss(
                P=pool,
                A=A,
                K=pool.size(-1),
                threshold="std",
            )
            loss += prob_imbalance_loss

        uses_dense_method = (
            self.pooling_method in self._DENSE_METHODS
            or self.regularization_method in self._DENSE_METHODS
        )
        if uses_dense_method:
            # following dense pooling usage in https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_mincut_pool.py
            if edge_index is None:
                raise ValueError("edge_index must be provided for dense pooling methods!")

            dense_x, mask = to_dense_batch(x, batch)
            # Features, assignments and adjacency must share the same
            # (batch, max_nodes) layout.
            dense_pool, _ = to_dense_batch(pool, batch)
            adj = to_dense_adj(
                edge_index,
                edge_attr=edge_weight,
                batch=batch,
                max_num_nodes=dense_x.size(1),
            )

            if "DiffPool" in (self.pooling_method, self.regularization_method):
                _pooled_x, _pooled_adj, link_loss, ent_loss = dense_diff_pool(
                    x=dense_x, adj=adj, s=dense_pool, mask=mask
                )
                if self.pooling_method == "DiffPool":
                    loss += link_loss
                    pooled_x = _pooled_x
                    pooled_adj = _pooled_adj
                if self.regularization_method == "DiffPool":
                    loss += ent_loss

            if "MinCutPool" in (self.pooling_method, self.regularization_method):
                _pooled_x, _pooled_adj, mincut_loss, ortho_loss = dense_mincut_pool(
                    x=dense_x, adj=adj, s=dense_pool, mask=mask
                )
                if self.pooling_method == "MinCutPool":
                    loss += mincut_loss
                    pooled_x = _pooled_x
                    pooled_adj = _pooled_adj
                if self.regularization_method == "MinCutPool":
                    loss += ortho_loss

            # issue with PyG DMoN implementation https://github.com/pyg-team/pytorch_geometric/issues/9413
            if "DMoNPool" in (self.pooling_method, self.regularization_method):
                _pooled_x, _pooled_adj, modularity_loss, cluster_loss = (
                    dense_dmon_pool(x=dense_x, adj=adj, s=dense_pool, mask=mask)
                )
                if self.pooling_method == "DMoNPool":
                    loss += modularity_loss
                    pooled_x = _pooled_x
                    pooled_adj = _pooled_adj
                if self.regularization_method == "DMoNPool":
                    loss += cluster_loss

            if "AsymCheegerCutPool" in (
                self.pooling_method,
                self.regularization_method,
            ):
                (
                    _pooled_x,
                    _pooled_adj,
                    tv_loss,
                    bal_loss,
                ) = self.asym_cheeger_cut_pool(
                    x=dense_x, adj=adj, s=dense_pool, mask=mask
                )
                if self.pooling_method == "AsymCheegerCutPool":
                    loss += tv_loss
                    pooled_x = _pooled_x
                    pooled_adj = _pooled_adj
                if self.regularization_method == "AsymCheegerCutPool":
                    loss += bal_loss

        return pool, pooled_x, pooled_adj, loss


class AutoPool(torch.nn.Module):
    r"""Pooling wrapper coupling a GNN encoder with a :class:`NeuralSBMPool`.

    The ``order`` argument selects how the encoder and the pooling layer are
    chained in :meth:`forward`:

    * ``"contextual"``: the GNN output is handed to the pooling layer as the
      assignment input ``s`` while the raw ``x`` is passed as ``x``; if an
      attribute decoder ``xnn`` exists it is applied to the GNN output.
    * ``"graph-first"``: the GNN output is passed to the pooling layer as
      ``x`` (and also as ``s`` when attributes are not reconstructed).
    * ``"attributes-first"``: the pooling layer runs first on the raw inputs;
      the GNN is then applied to the pooled adjacency and its output is
      un-pooled back to node level. Attribute reconstruction is always
      enabled in this mode.

    Args:
        max_clusters: Maximum number of clusters.
        pooling_method: Name of the pooling method forwarded to
            :class:`NeuralSBMPool`. ``"NOCD"`` is treated as ``"SBMPool"``
            with ``no_blocks=True``.
        regularization_method: Forwarded to :class:`NeuralSBMPool`.
        order: ``"contextual"``, ``"graph-first"`` or ``"attributes-first"``.
        gnn: Encoder module. If ``None`` a :class:`GCN2` is created.
        graph_channels: Input channels of the default encoder (overridden by
            ``max_clusters`` when ``order == "attributes-first"``).
        num_attributes: Number of node attributes. Required when attributes
            are reconstructed, when ``order == "attributes-first"`` or when a
            custom ``xnn`` is supplied.
        device: Stored on the module as ``self.device``.
        snn: Module passed to :class:`NeuralSBMPool` as ``nn`` (ignored for
            ``order == "contextual"``). A default :class:`MLP` is created for
            ``"SBMPool"`` in the non-contextual attribute settings.
        bnn: Module passed to :class:`NeuralSBMPool` as ``bnn``. A default
            :class:`MLP` is created for ``"SBMPool"`` unless ``no_blocks``.
        max_clusters_per_node, overlapping, observed_graph_is_dense,
        sparse_sbm, graph_reconstruction_method, num_neg_samples, directed,
        weighted_graph, fixed_blocks, no_blocks, adj: Forwarded to
            :class:`NeuralSBMPool` (see that class for their meaning).
        xnn: Attribute decoder. If ``None`` and attributes are reconstructed
            (and ``order`` is not ``"attributes-first"``) an :class:`MLP` is
            created.
        reconstruct_attributes: Add an attribute reconstruction loss.
        sample_attributes: Compute the attribute loss on a balanced sample of
            non-zero and zero entries of ``x`` instead of on all entries.
        num_attribute_samples: Upper bound on the number of sampled entries
            per class (non-zero / zero).
        weighted_x: Use a mean-squared-error attribute loss if ``True``,
            otherwise a cross-entropy loss.
        bayesian: Accepted for interface compatibility; not used by this
            class.

    Raises:
        ValueError: If ``num_attributes`` is required (see above) but missing.
    """

    def __init__(
        self,
        max_clusters: int,
        pooling_method: str = "SBMPool",
        regularization_method: Optional[str] = None,
        order: str = "contextual",  # "attributes-first" or "graph-first" or "contextual"
        gnn: Optional[torch.nn.Module] = None,
        graph_channels: Optional[int] = None,
        num_attributes: Optional[int] = None,
        device: torch.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        ),
        # neural(-prior) SBM args
        snn: Optional[torch.nn.Module] = None,
        bnn: Optional[torch.nn.Module] = None,
        max_clusters_per_node=None,
        overlapping: bool = True,
        observed_graph_is_dense: bool = False,
        sparse_sbm: bool = True,
        graph_reconstruction_method: str = "sample",
        num_neg_samples: Optional[int] = None,
        directed: bool = False,
        weighted_graph: bool = False,
        fixed_blocks: Optional[Tensor] = None,
        no_blocks: bool = False,
        # MapEqPool args
        adj: Optional[Tensor] = None,
        # attribute reconstruction args
        xnn: Optional[torch.nn.Module] = None,
        reconstruct_attributes: bool = False,
        sample_attributes: bool = False,
        num_attribute_samples: Optional[int] = None,
        weighted_x: bool = True,
        # bayesian gnn args
        bayesian: bool = True,
    ):
        super().__init__()

        # NOTE: the user-facing name is stored (e.g. "NOCD"), not the
        # remapped one that is handed to NeuralSBMPool below.
        self.pooling_method = pooling_method
        self.order = order
        self.device = device
        # SBM args
        self.sparse_sbm = sparse_sbm
        # attribute reconstruction args
        self.reconstruct_attributes = reconstruct_attributes
        self.sample_attributes = sample_attributes
        self.num_attribute_samples = num_attribute_samples
        self.weighted_x = weighted_x

        if num_attributes is None and reconstruct_attributes:
            raise ValueError(
                "num_attributes must be provided for attribute reconstruction!"
            )
        # May still be None here; it is either replaced by cluster_channels
        # below or rejected explicitly before it is used.
        attribute_channels = num_attributes

        if order == "attributes-first":
            self.reconstruct_attributes = True
            graph_channels = max_clusters

        cluster_channels = max_clusters

        if pooling_method == "NOCD":
            pooling_method = "SBMPool"
            no_blocks = True

        if pooling_method == "SBMPool":
            cluster_channels += max_clusters * (2 if directed else 1)

        self.xnn = xnn
        if self.xnn is None and order != "attributes-first":
            if reconstruct_attributes:
                self.xnn = MLP(
                    in_channels=cluster_channels,
                    out_channels=attribute_channels,
                )
            # Without a user-supplied decoder the encoder emits cluster-sized
            # embeddings directly.
            attribute_channels = cluster_channels

        if attribute_channels is None:
            # Previously this configuration crashed with an UnboundLocalError
            # further down; fail early with an actionable message instead.
            raise ValueError(
                "num_attributes must be provided when order is "
                "'attributes-first' or when a custom xnn is supplied!"
            )

        if gnn is None:
            self.gnn = GCN2(
                in_channels=graph_channels,
                out_channels=attribute_channels,
            )
        else:
            self.gnn = gnn

        if pooling_method == "SBMPool":
            if (
                order == "attributes-first"
                or (order == "graph-first" and reconstruct_attributes)
            ) and snn is None:
                snn = MLP(
                    in_channels=attribute_channels,
                    out_channels=cluster_channels,
                )
            if bnn is None and not no_blocks:
                bnn = MLP(
                    in_channels=cluster_channels,
                    out_channels=cluster_channels,
                )

        self.sbm_pool = NeuralSBMPool(
            in_channels=attribute_channels,
            max_clusters=max_clusters,
            pooling_method=pooling_method,
            regularization_method=regularization_method,
            nn=(None if order == "contextual" else snn),
            # NeuralSBM args
            bnn=bnn,
            max_clusters_per_node=max_clusters_per_node,
            overlapping=overlapping,
            directed=directed,
            weighted=weighted_graph,
            no_blocks=no_blocks,
            fixed_blocks=fixed_blocks,
            sparse=sparse_sbm,
            # link prediction loss args
            graph_reconstruction_method=graph_reconstruction_method,
            num_neg_samples=num_neg_samples,
            observed_graph_is_dense=observed_graph_is_dense,
            # MapEqPool args
            adj=adj,
        )

    # TODO: Check if this is correct for dense to sparse batching
    def mask_to_batch(self, mask: Optional[Tensor]) -> Optional[Tensor]:
        """Turn a per-graph node mask into a flat batch-assignment vector.

        Row ``i`` of ``mask`` contributes ``mask[i].sum()`` copies of the
        index ``i`` to the result. ``None`` is passed through unchanged.
        """
        if mask is None:
            return None
        graph_ids = torch.arange(mask.size(0), device=mask.device)
        return graph_ids.repeat_interleave(mask.sum(dim=1))

    def _attribute_loss(self, x: Tensor, pred_x: Tensor) -> Tensor:
        """Return the attribute reconstruction loss of ``pred_x`` against ``x``.

        With ``sample_attributes`` the loss is evaluated on an equal number of
        randomly chosen non-zero and zero entries of ``x`` (at most
        ``num_attribute_samples`` each). If no such balanced sample exists
        (``x`` has no zero or no non-zero entries, or the sample budget is
        zero) the loss falls back to all entries.
        """
        target, pred = x, pred_x

        if self.sample_attributes:
            pos_indices = torch.nonzero(x != 0)
            neg_indices = torch.nonzero(x == 0)
            num_pos = pos_indices.size(0)
            num_neg = neg_indices.size(0)

            num_samples = min(num_pos, num_neg)
            if self.num_attribute_samples is not None:
                num_samples = min(num_samples, self.num_attribute_samples)

            if num_samples > 0:
                # Draw without replacement from the exact sets of non-zero
                # and zero entries, so every index is valid for ``x``.
                pos_choice = torch.randperm(num_pos, device=x.device)[:num_samples]
                neg_choice = torch.randperm(num_neg, device=x.device)[:num_samples]
                chosen = torch.cat(
                    [pos_indices[pos_choice], neg_indices[neg_choice]], dim=0
                )
                # One index tensor per dimension -> element-wise selection.
                index = chosen.unbind(dim=1)
                target = x[index]
                pred = pred_x[index]

        if self.weighted_x:
            return F.mse_loss(input=pred, target=target)
        # NOTE(review): on the sampled path both tensors are 1-D, so the
        # softmax runs over the sampled entries -- confirm this is intended.
        return F.cross_entropy(input=pred, target=target)

    def forward(
        self,
        x: Optional[Tensor] = None,
        # sparse
        edge_index: Optional[Tensor] = None,
        edge_weight: Optional[Tensor] = None,
        # TODO: support more general edge_attr: Optional[Tensor] = None,
        batch: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Tensor]:
        """Pool the input graph and return assignments, pooled graph and loss.

        Sparse pooling is supported for SBMPool and dense pooling is supported
        for dense baseline pooling methods (e.g. DiffPool, MinCutPool,
        DMoNPool).

        Sparse pooling uses sparse batching and edge_index and (optionally)
        edge_weight or edge_attr inputs e.g. TopKPooling
        https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.TopKPooling.html#torch_geometric.nn.pool.TopKPooling
        https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_topk_pool.py
        Sparse batching follows
        https://pytorch-geometric.readthedocs.io/en/stable/advanced/batching.html

        Dense pooling uses dense batching and adj input e.g. dense_diff_pool
        https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.dense.dense_diff_pool.html#torch_geometric.nn.dense.dense_diff_pool
        https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_diff_pool.py
        With dense pooling methods, sparse inputs (and sparse batching) are
        supported by converting to dense inputs after GNN using
        to_dense_batch and to_dense_adj.

        TODO: Support dense pooling for SBMPool; need to implement dense to
        sparse batching (sparse to dense batching is already implemented in
        PyG).

        Args:
            x: Node attributes. Required when attributes are reconstructed.
            edge_index: Edge indices of the (sparse) input graph.
            edge_weight: Optional edge weights.
            batch: Optional batch-assignment vector.

        Returns:
            ``(pool, pred_pooled_x, pred_pooled_adj, loss)`` as produced by
            the pooling layer, with the attribute reconstruction loss added
            to ``loss`` when enabled.

        Raises:
            ValueError: If attributes are reconstructed but ``x`` is ``None``.
        """
        if self.order == "attributes-first":
            pool, pred_pooled_x, pred_pooled_adj, loss = self.sbm_pool(
                x=x,
                edge_index=edge_index,
                edge_weight=edge_weight,
                batch=batch,
            )

            if self.sparse_sbm:
                unpool = F.normalize(pool, p=1, dim=1)
            else:
                unpool = F.softmax(pool, dim=1)

            # Run the encoder on the pooled graph, then map its output back
            # to node level through the (normalised) assignments.
            edge_index, edge_weight = dense_to_sparse(pred_pooled_adj)
            pred_x = torch.matmul(
                unpool,
                self.gnn(
                    x=pred_pooled_adj,
                    edge_index=edge_index,
                    edge_weight=edge_weight,
                ),
            )
            pred_pooled_x = torch.matmul(pool.transpose(0, 1), pred_x)
        else:
            pred_x = self.gnn(
                x=x,
                edge_index=edge_index,
                edge_weight=edge_weight,
            )  # (n, n) @ (n, d) @ (d, d) -> (n, d)

            contextual = self.order == "contextual"
            # The encoder output doubles as the assignment input unless the
            # pooling layer has to derive assignments itself.
            encoder_gives_assignments = contextual or (
                self.order == "graph-first"
                and self.reconstruct_attributes is False
            )
            pool, pred_pooled_x, pred_pooled_adj, loss = self.sbm_pool(
                x=(x if contextual else pred_x),
                s=(pred_x if encoder_gives_assignments else None),
                edge_index=edge_index,
                edge_weight=edge_weight,
                batch=batch,
            )

            if contextual and self.xnn is not None:
                pred_x = self.xnn(pred_x)

        # attribute reconstruction loss
        if self.reconstruct_attributes:
            if x is None:
                raise ValueError("x must be provided for attribute reconstruction!")
            loss = loss + self._attribute_loss(x=x, pred_x=pred_x)

        return pool, pred_pooled_x, pred_pooled_adj, loss
