import torch
from torch_geometric.nn import ARMAConv
from torch_geometric.nn.resolver import activation_resolver

from tgp.poolers import get_pooler
from tgp.mp import GTVConv


class ClusterModel(torch.nn.Module):
    """Stack of message-passing (MP) layers followed by a pooling layer.

    The pooler is built through ``get_pooler`` and must expose a truthy
    ``has_loss`` attribute. The MP layers are ``GTVConv`` when ``pooler`` is
    ``"acc"`` and ``ARMAConv`` otherwise.

    Args:
        in_channels (int): Size of node features
        num_layers_pre (int): Number of MP layers before pooling
        hidden_channels (int): Dimensionality of node embeddings
        activation (str): Activation of the MP layers
        pooler (str): Pooling method
        pool_kwargs (dict): Pooling method kwargs. ``None`` (the default) is
            treated as "no extra kwargs".
        pooled_nodes (int): Number of nodes after pooling (forwarded to the
            pooler as ``k``)

    Raises:
        AssertionError: If the selected pooler reports no auxiliary loss.
    """

    def __init__(
        self,
        in_channels,
        num_layers_pre=1,
        hidden_channels=64,
        activation="ELU",
        pooler=None,
        pool_kwargs=None,
        pooled_nodes=None,
    ):
        super().__init__()

        # ``**None`` raises TypeError, so the documented default must be
        # normalised to an empty mapping before it is unpacked below.
        if pool_kwargs is None:
            pool_kwargs = {}

        self.num_layers_pre = num_layers_pre
        self.hidden_channels = hidden_channels
        self.act = activation_resolver(activation)
        # The pooler is created before the MP layers so that submodule
        # registration order stays the same as it has always been.
        self.pool = get_pooler(
            pooler, in_channels=hidden_channels, k=pooled_nodes, **pool_kwargs
        )

        assert self.pool.has_loss, "The pooler must have an auxiliary loss."

        # Pre-pooling block. The layer type depends only on ``pooler``, so it
        # is decided once instead of on every loop iteration.
        if pooler in ["acc"]:
            conv_cls, conv_kwargs = GTVConv, {"delta_coeff": 0.3}
        else:
            conv_cls, conv_kwargs = ARMAConv, {}

        self.conv_layers_pre = torch.nn.ModuleList()
        layer_in = in_channels  # only the first layer sees the raw features
        for _ in range(num_layers_pre):
            self.conv_layers_pre.append(
                conv_cls(
                    in_channels=layer_in,
                    out_channels=hidden_channels,
                    **conv_kwargs,
                )
            )
            layer_in = hidden_channels

    def forward(self, data):
        """Run the MP stack on ``data`` and pool the resulting embeddings.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``edge_attr``
                attributes (presumably a PyG ``Data``/``Batch`` -- confirm
                against callers).

        Returns:
            Whatever ``self.pool(x=..., adj=...)`` returns; the exact type is
            defined by the configured pooler.
        """
        # Pre-pooling block: ``edge_attr`` is handed to every MP layer as its
        # third positional argument.
        x = data.x
        for layer in self.conv_layers_pre:
            x = self.act(layer(x, data.edge_index, data.edge_attr))

        # Pooling block: only the second element returned by the pooler's
        # ``preprocessing`` step is used; it is passed on as ``adj``.
        _, adj, _ = self.pool.preprocessing(
            x=x, edge_index=data.edge_index, edge_attr=data.edge_attr, use_cache=True
        )
        out = self.pool(x=x, adj=adj)

        return out
