import torch
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.nn import GINConv, GINEConv, MLP, DenseGINConv, PANConv
from tgp.poolers import get_pooler, pooler_map
from tgp.reduce import BaseReduce 
from tgp.src import PoolingOutput



class ClassificationModel(torch.nn.Module):
    """Graph classifier: message passing -> pooling -> message passing -> readout.

    Pre-pooling layers:
    - ``PANConv`` when ``pooler == "pan"``; its ``M`` output feeds the PAN pooler.
    - Otherwise ``GINEConv`` when ``use_gine`` is set and ``edge_channels`` is given.
    - Otherwise ``GINConv``.

    Post-pooling layers are ``DenseGINConv`` if the pooler reports ``is_dense``,
    otherwise ``GINConv``. A sum global pool followed by an MLP produces the
    output.

    Args:
        in_channels: Size of node features.
        out_channels: Number of classes (width of the readout MLP output).
        edge_channels: Size of edge features (only used by GINE).
        num_layers_pre: Number of message-passing layers before pooling.
        num_layers_post: Number of GIN layers after pooling.
        hidden_channels: Dimensionality of node embeddings.
        activation: Activation name. It is resolved with ``activation_resolver``
            and also passed to the GIN MLPs.
        pooler: Pooling method name passed to ``get_pooler``.
        pool_kwargs: Extra keyword arguments for ``get_pooler``. ``None``
            means no extra arguments.
        pooled_nodes: Number of nodes after pooling (passed as ``k``).
        use_gine: Use GINE instead of GIN. Ignored when ``edge_channels`` is None.
        dropout: Dropout of the readout MLP.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        edge_channels=None,
        num_layers_pre=1,
        num_layers_post=1,
        hidden_channels=64,
        activation="ELU",
        pooler=None,
        pool_kwargs=None,
        pooled_nodes=None,
        use_gine=False,
        dropout=0.5,
    ):
        super().__init__()

        # Normalize once. Both the PAN filter-size lookup and the
        # ``**pool_kwargs`` expansion in ``get_pooler`` need a mapping;
        # expanding ``None`` would raise a TypeError.
        pool_kwargs = {} if pool_kwargs is None else pool_kwargs

        self.num_classes = out_channels
        self.act = activation_resolver(activation)
        # GINE needs edge features; without them fall back to plain GIN.
        self.using_gine = edge_channels is not None and bool(use_gine)

        ### Pre-pooling block
        self.conv_layers_pre = torch.nn.ModuleList()
        for _ in range(num_layers_pre):
            if pooler == "pan":
                # PANConv for PAN pooling (requires filter_size parameter)
                filter_size = pool_kwargs.get("filter_size", 3)
                self.conv_layers_pre.append(
                    PANConv(
                        in_channels=in_channels,
                        out_channels=hidden_channels,
                        filter_size=filter_size,
                    )
                )
            else:
                # Standard GIN/GINE layers for other poolers
                mlp = MLP([in_channels, hidden_channels, hidden_channels], act=activation)
                if self.using_gine:
                    self.conv_layers_pre.append(GINEConv(nn=mlp, edge_dim=edge_channels))
                else:
                    self.conv_layers_pre.append(GINConv(nn=mlp))
            # Every layer after the first consumes hidden-size embeddings.
            in_channels = hidden_channels

        ### Pooling block
        self.pooler = get_pooler(
            pooler, in_channels=in_channels, k=pooled_nodes, **pool_kwargs
        )

        ### Post-pooling block
        self.conv_layers_post = torch.nn.ModuleList()
        for _ in range(num_layers_post):
            mlp = MLP(
                [hidden_channels, hidden_channels, hidden_channels],
                act=activation,
                norm=None,
            )
            # Dense poolers produce a dense adjacency, which needs DenseGINConv.
            if self.pooler.is_dense:
                self.conv_layers_post.append(DenseGINConv(nn=mlp))
            else:
                self.conv_layers_post.append(GINConv(nn=mlp))

        ### Readout
        self.mlp = MLP(
            [hidden_channels, hidden_channels // 2, out_channels],
            act=activation,
            norm=None,
            dropout=dropout,
        )

    def forward(self, data):
        """Run the model on a (batched) graph.

        Args:
            data: Graph object exposing ``x``, ``edge_index``, ``batch`` and,
                optionally, ``edge_attr``. When the pooler is not trainable,
                ``data.pooled_data`` must also provide the precomputed
                ``so``, ``edge_index`` and ``batch``.

        Returns:
            Tuple ``(x, out)``:
            - ``x``: output of the readout MLP (``out_channels`` columns).
            - ``out``: the ``PoolingOutput`` from the pooling step.

        Raises:
            RuntimeError: If the PAN pooler is used without any pre-pooling
                PANConv layer, so no ``M`` matrix is available.
        """
        x = data.x
        adj = data.edge_index
        ea = getattr(data, "edge_attr", None)
        batch = data.batch

        # Loop-invariant: whether the PAN-specific path is taken.
        is_pan = isinstance(self.pooler, pooler_map["pan"])

        ### Pre-pooling block
        M = None  # Set by PANConv layers; consumed by the PAN pooler.
        for layer in self.conv_layers_pre:
            if is_pan:
                # PANConv returns (x, M) where M is used by the PAN pooler
                x, M = layer(x, adj)
                x = self.act(x)
            elif self.using_gine:
                x = self.act(layer(x, adj, edge_attr=ea))
            else:
                x = self.act(layer(x, adj))

        ### Pooling block
        if not self.pooler.is_trainable:
            # Pooling was precomputed: reduce node features with the stored
            # assignment and reuse the stored coarsened connectivity.
            x, _ = BaseReduce()(x=x, so=data.pooled_data.so)
            adj = data.pooled_data.edge_index
            out = PoolingOutput(
                x=x,
                so=data.pooled_data.so,
                edge_index=adj,
                batch=data.pooled_data.batch,
            )
        elif is_pan:
            if M is None:
                raise RuntimeError(
                    "PAN pooling requires at least one PANConv layer before "
                    "pooling (num_layers_pre >= 1)."
                )
            # PAN pooler uses the matrix M from the last PANConv layer
            out = self.pooler(x=x, adj=M, batch=batch)
            x, adj = out.x, out.edge_index
        else:
            # Only one value per edge can act as an edge weight. Drop
            # multi-feature edge attributes ([E, F] with F > 1). Keep 1-D
            # ([E]) and single-column ([E, 1]) ones. Note that for 1-D
            # tensors ``size(-1)`` is the edge count, not a feature count.
            if ea is not None and ea.dim() > 1 and ea.size(-1) > 1:
                ea = None
            x, adj, mask = self.pooler.preprocessing(
                x=x,
                edge_index=adj,
                edge_weight=ea,
                batch=batch,
                use_cache=False,
            )
            # NOTE(review): ``ea`` is passed both to preprocessing and to the
            # pooler. Confirm preprocessing does not already fold it into ``adj``.
            out = self.pooler(x=x, adj=adj, edge_weight=ea, batch=batch, mask=mask)
            x, adj = out.x, out.edge_index

        ### Post-pooling block
        for layer in self.conv_layers_post:
            x = self.act(layer(x, adj))

        ### Readout
        x = self.pooler.global_pool(x, reduce_op="sum", batch=out.batch)
        x = self.mlp(x)

        return x, out
