import torch
import torch.nn.functional as F
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.nn import GINConv, GINEConv, MLP, DenseGINConv, PANConv
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from tgp.poolers import get_pooler, pooler_map
from tgp.reduce import BaseReduce
from tgp.src import PoolingOutput


class OGBModel(torch.nn.Module):
    """Graph-level model for OGB molecular graphs with a single pooling step.

    Pipeline: OGB atom (and, for GINE, bond) embeddings -> pre-pooling
    message-passing layers with batch norm -> pooling -> post-pooling
    message-passing layers -> global sum readout -> MLP head.

    Args:
        in_channels: Size of node features. NOTE: the value is not used; node
            features are embedded by ``AtomEncoder`` to ``EMBED_DIM`` channels.
        out_channels: Number of outputs of the readout MLP (number of classes).
        pooler: Name of the pooling method, forwarded to ``get_pooler``.
            ``"pan"`` switches the pre-pooling layers to ``PANConv``.
        edge_channels: Size of edge features. Only whether it is ``None``
            matters (GINE is enabled only if it is not ``None``); bond
            features are embedded by ``BondEncoder`` to ``EMBED_DIM`` channels.
        num_layers_pre: Number of message-passing layers before pooling.
        num_layers_post: Number of message-passing layers after pooling.
        hidden_channels: Dimensionality of node embeddings.
        activation: Activation name, resolved with ``activation_resolver``
            and also passed to the MLPs.
        dropout: Dropout probability applied after each pre-pooling layer and
            before the readout MLP.
        pool_kwargs: Extra keyword arguments for the pooler. ``None`` is
            treated as an empty dict. For ``"pan"``, its ``filter_size``
            entry (default 3) is also used by the ``PANConv`` layers.
        pooled_nodes: Passed to ``get_pooler`` as ``k``.
        use_gine: Use ``GINEConv`` instead of ``GINConv`` before pooling
            (effective only when ``edge_channels`` is not ``None``).

    Raises:
        ValueError: If ``pooler == "pan"`` and ``num_layers_pre < 1``; the PAN
            pooler consumes the matrix produced by the last ``PANConv``.
    """

    # Width of the OGB atom/bond embeddings feeding the first conv layer.
    EMBED_DIM = 100

    def __init__(
        self,
        in_channels,  # Size of node features
        out_channels,  # Number of classes
        pooler,  # Pooling method
        edge_channels=None,  # Size of edge features
        num_layers_pre=1,  # Number of GIN layers before pooling
        num_layers_post=1,  # Number of GIN layers after pooling
        hidden_channels=64,  # Dimensionality of node embeddings
        activation="relu",  # Activations
        dropout=0.0,  # Dropout in the MLP
        pool_kwargs=None,  # Pooling method kwargs
        pooled_nodes=None,  # Number of nodes after pooling
        use_gine=False,  # Use GINE instead of GIN
    ):
        super().__init__()

        # The default ``None`` cannot be unpacked with ``**``; normalise it.
        pool_kwargs = {} if pool_kwargs is None else pool_kwargs

        is_pan = pooler == "pan"
        if is_pan and num_layers_pre < 1:
            raise ValueError(
                "The 'pan' pooler requires num_layers_pre >= 1 because it "
                "pools on the matrix returned by the last PANConv layer."
            )

        self.num_classes = out_channels
        self.act = activation_resolver(activation)
        self.using_gine = bool(edge_channels is not None and use_gine)
        self.dropout = dropout

        ### Node and edge embeddings for ogb
        # The bond encoder is always created (even when unused) so that the
        # set of registered parameters does not depend on ``use_gine``.
        self.atom_encoder = AtomEncoder(self.EMBED_DIM)
        self.bond_encoder = BondEncoder(self.EMBED_DIM)
        in_channels = self.EMBED_DIM
        edge_channels = self.EMBED_DIM

        ### Pre-pooling block
        self.conv_layers_pre = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for _ in range(num_layers_pre):
            if is_pan:
                # PANConv for PAN pooling (requires filter_size parameter)
                filter_size = pool_kwargs.get("filter_size", 3)
                self.conv_layers_pre.append(
                    PANConv(
                        in_channels=in_channels,
                        out_channels=hidden_channels,
                        filter_size=filter_size,
                    )
                )
            else:
                # MP layers
                mlp = MLP(
                    [in_channels, hidden_channels, hidden_channels],
                    act=activation,
                    norm=None,
                )
                if self.using_gine:
                    self.conv_layers_pre.append(
                        GINEConv(nn=mlp, train_eps=False, edge_dim=edge_channels)
                    )
                else:
                    self.conv_layers_pre.append(GINConv(nn=mlp, train_eps=False))
            in_channels = hidden_channels

            # BatchNorm layers
            self.batch_norms.append(torch.nn.BatchNorm1d(hidden_channels))

        ### Pooling block
        self.pooler = get_pooler(
            pooler, in_channels=in_channels, k=pooled_nodes, **pool_kwargs
        )

        ### Post-pooling block
        self.conv_layers_post = torch.nn.ModuleList()
        for _ in range(num_layers_post):
            mlp = MLP(
                [hidden_channels, hidden_channels, hidden_channels],
                act=activation,
                norm=None,
            )
            # Dense poolers need the dense GIN variant after pooling.
            if self.pooler.is_dense:
                self.conv_layers_post.append(DenseGINConv(nn=mlp))
            else:
                self.conv_layers_post.append(GINConv(nn=mlp))

        ### Readout
        self.mlp = MLP(
            [hidden_channels, hidden_channels // 2, out_channels],
            act=activation,
            norm=None,
        )

    def forward(self, data):
        """Run the model on a batch of graphs.

        Args:
            data: Batch object providing ``x``, ``edge_index`` and ``batch``;
                ``edge_attr`` is required only when GINE layers are used.
                When the pooler is not trainable, ``data.pooled_data`` must
                provide the precomputed ``so``, ``edge_index`` and ``batch``.

        Returns:
            A tuple ``(x, out)`` where ``x`` is the output of the readout MLP
            and ``out`` is the ``PoolingOutput`` of the pooling step.

        Raises:
            ValueError: If GINE layers are used but ``data`` has no
                ``edge_attr``.
        """
        x = data.x
        adj = data.edge_index
        ea = getattr(data, "edge_attr", None)
        batch = data.batch

        # Loop-invariant: decides both the conv call signature and the
        # pooling path below.
        is_pan = isinstance(self.pooler, pooler_map["pan"])

        ### Node and edge embeddings
        x = self.atom_encoder(x)
        # Bond embeddings are consumed only by the GINE layers; skip them
        # otherwise so graphs without edge attributes still work.
        ea_emb = None
        if self.using_gine and not is_pan:
            if ea is None:
                raise ValueError(
                    "GINE layers require `data.edge_attr`, but it is missing."
                )
            ea_emb = self.bond_encoder(ea)

        ### Pre-pooling block
        M = None
        for mp_layer, bn_layer in zip(self.conv_layers_pre, self.batch_norms):
            if is_pan:
                # PANConv returns (x, M) where M is used by PAN pooler
                x, M = mp_layer(x, adj)
                x = self.act(x)
            elif self.using_gine:
                x = self.act(mp_layer(x, adj, edge_attr=ea_emb))
            else:
                x = self.act(mp_layer(x, adj))
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = bn_layer(x)

        ### Pooling block
        if not self.pooler.is_trainable:
            # Non-trainable pooler: reuse the assignment and coarsened graph
            # stored on the batch and only reduce the node features here.
            x, _ = BaseReduce()(x=x, so=data.pooled_data.so)
            adj = data.pooled_data.edge_index
            out = PoolingOutput(
                x=x,
                so=data.pooled_data.so,
                edge_index=adj,
                batch=data.pooled_data.batch,
            )
        elif is_pan:
            # PAN pooler uses the matrix M from PANConv
            out = self.pooler(x=x, adj=M, batch=batch)
            x, adj = out.x, out.edge_index
        else:
            # Standard pooling flow for other poolers
            x, adj, mask = self.pooler.preprocessing(
                x=x,
                edge_index=adj,
                edge_weight=None,
                batch=batch,
                use_cache=False,
            )
            out = self.pooler(x=x, adj=adj, edge_weight=None, batch=batch, mask=mask)
            x, adj = out.x, out.edge_index

        ### Post-pooling block
        for layer in self.conv_layers_post:
            x = self.act(layer(x, adj))

        ### Readout
        x = self.pooler.global_pool(x, reduce_op="sum", batch=out.batch)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.mlp(x)

        return x, out
