import torch.nn as nn
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.layer import MLP, new_layer_config
from torch_geometric.graphgym.register import register_head

from grit.slt.sparse_modules import SparseLinear, SparseLinearMulti_mask


@register_head("inductive_node")
class GNNInductiveNodeHead(nn.Module):
    """
    GNN prediction head for inductive node prediction tasks.

    The architecture is chosen from the global GraphGym ``cfg`` at
    construction time:

    * If ``cfg.slt.pred`` is True and either ``cfg.slt.sm`` or ``cfg.slt.mm``
      is True, the head is a plain stack of ``cfg.gnn.layers_post_mp`` sparse
      linear layers (``SparseLinear`` when ``cfg.slt.sm`` is True, otherwise
      ``SparseLinearMulti_mask``). The first ``layers_post_mp - 1`` layers map
      ``dim_in -> dim_in`` and the last maps ``dim_in -> dim_out``. No
      activation is inserted between them here.
    * Otherwise, the head is GraphGym's ``MLP`` built via ``new_layer_config``
      with ``cfg.gnn.layers_post_mp`` layers, ``has_bias=True`` and
      ``has_act=False``.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension. For binary prediction, dim_out=1.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Identity checks against True are kept deliberately: only a literal
        # True in the config enables the sparse head.
        use_sparse = (
            cfg.slt.sm is True or cfg.slt.mm is True
        ) and cfg.slt.pred is True

        if use_sparse:
            # ``cfg.slt.sm`` takes precedence when both flags are set.
            layer_cls = SparseLinear if cfg.slt.sm is True else SparseLinearMulti_mask
            self.layers = nn.ModuleList()
            for _ in range(cfg.gnn.layers_post_mp - 1):
                self.layers.append(layer_cls(dim_in, dim_in))
            self.layers.append(layer_cls(dim_in, dim_out))
        else:
            self.layer_post_mp = MLP(
                new_layer_config(
                    dim_in,
                    dim_out,
                    cfg.gnn.layers_post_mp,
                    has_act=False,
                    has_bias=True,
                    cfg=cfg,
                )
            )

    def _apply_index(self, batch):
        """Return the ``(batch.x, batch.y)`` pair used as ``(pred, label)``."""
        return batch.x, batch.y

    def forward(
        self,
        batch,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Run the head on ``batch`` and return ``(pred, label)``.

        Only ``pred_th`` and ``global_th`` are read by this head. The other
        keyword arguments (``cur_epoch``, ``mpnn_th``, ``msa_th``, ``ffn_th``,
        ``encoder_th``) are accepted but ignored, presumably so callers can
        pass one uniform set of keywords to every module (verify against
        callers).

        Args:
            batch: Graph batch exposing ``x`` (node features) and ``y``
                (labels).
            pred_th: Threshold passed to each sparse layer. Takes precedence
                over ``global_th``.
            global_th: Fallback threshold used when ``pred_th`` is None.

        Returns:
            tuple: ``(pred, label)`` where ``label`` is ``batch.y``.
        """
        # Prefer the prediction-specific threshold, then the global one.
        # It stays None when neither is given, so the sparse layers are then
        # called without a threshold argument instead of hitting an unbound
        # local variable.
        threshold = pred_th if pred_th is not None else global_th

        # ``layers`` exists only when the sparse head was built in __init__.
        if hasattr(self, "layers"):
            x = batch.x
            for layer in self.layers:
                x = layer(x) if threshold is None else layer(x, threshold)
            return x, batch.y

        # MLP path: assumes GraphGym's MLP writes its output back into
        # ``batch.x`` and returns the batch (TODO confirm for the installed
        # GraphGym version).
        batch = self.layer_post_mp(batch)
        return self._apply_index(batch)
