import torch.nn as nn
import torch_geometric.data
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.layer import MLP, new_layer_config
from torch_geometric.graphgym.register import register_head

from graphgps.slt.monarch_linear import MonarchLinear
from graphgps.slt.sparse_modules import (
    NMSparseMultiLinear,
    SparseLinear,
    SparseLinearMulti_mask,
)


class CustomSequential(nn.Module):
    """Sequential container whose layers may receive an extra argument.

    Unlike ``nn.Sequential``, ``forward`` accepts either a
    ``torch_geometric.data.Batch`` (its ``x`` attribute is transformed and
    written back) or a plain feature tensor. When SLT single-/multi-mask
    encoder mode is enabled in ``cfg``, every non-ReLU module is called as
    ``module(features, y)``; otherwise as ``module(features)``.
    """

    def __init__(self, *args):
        super(CustomSequential, self).__init__()
        self.mlpmodules = nn.ModuleList(args)

    def forward(self, x, y=None):
        """Run all modules in order over the node features.

        Args:
            x: A ``torch_geometric.data.Batch`` (its ``x`` attribute is used
                as the features) or a features object passed to the modules
                directly.
            y: Extra argument forwarded to non-ReLU modules when
                ``cfg.slt.sm``/``cfg.slt.mm`` and ``cfg.slt.encoder`` are set
                (the head passes a pruning threshold here).

        Returns:
            For ``Batch`` input, the same batch with ``x`` replaced by the
            transformed features; otherwise the transformed features.
        """
        is_batch = isinstance(x, torch_geometric.data.Batch)
        # Always operate on the feature tensor, never on the Batch object itself,
        # so that activations (nn.ReLU) receive a tensor too.
        features = x.x if is_batch else x

        # Config does not change during a forward pass: evaluate once.
        pass_extra = (
            cfg.slt.sm is True or cfg.slt.mm is True
        ) and cfg.slt.encoder is True

        for module in self.mlpmodules:
            if isinstance(module, nn.ReLU):
                features = module(features)
            elif pass_extra:
                features = module(features, y)
            else:
                features = module(features)

        if is_batch:
            x.x = features
            return x
        return features


@register_head("inductive_node")
class GNNInductiveNodeHead(nn.Module):
    """
    GNN prediction head for inductive node prediction tasks.

    When ``cfg.slt.encoder`` is True the head is a ``CustomSequential`` of
    three bias-free layers (dim_in -> dim_in -> dim_in -> dim_out) whose class
    is selected from the Monarch/SLT config flags. Otherwise it is a GraphGym
    ``MLP`` with ``cfg.gnn.layers_post_mp`` layers.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension. For binary prediction, dim_out=1.

    Raises:
        ValueError: If ``cfg.slt.encoder`` is True but none of
            ``cfg.monarch.encoder``, ``cfg.slt.srste``, ``cfg.slt.sm`` or
            ``cfg.slt.mm`` is set.
    """

    def __init__(self, dim_in, dim_out):
        super(GNNInductiveNodeHead, self).__init__()

        if cfg.slt.encoder is True:
            # The first flag that is set wins, in the order checked below.
            if cfg.monarch.encoder:
                layer_class = MonarchLinear
            elif cfg.slt.srste:
                layer_class = NMSparseMultiLinear
            elif cfg.slt.sm:
                layer_class = SparseLinear
            elif cfg.slt.mm:
                layer_class = SparseLinearMulti_mask
            else:
                # Previously this fell through to an UnboundLocalError.
                raise ValueError(
                    "cfg.slt.encoder is True but no layer type is selected; "
                    "set one of cfg.monarch.encoder, cfg.slt.srste, "
                    "cfg.slt.sm or cfg.slt.mm."
                )

            # NOTE(review): depth is fixed at 3 here and does not follow
            # cfg.gnn.layers_post_mp as the MLP branch does.
            layers = [
                layer_class(dim_in, dim_in, bias=False),
                layer_class(dim_in, dim_in, bias=False),
                layer_class(dim_in, dim_out, bias=False),
            ]

            self.layer_post_mp = CustomSequential(*layers)
        else:
            self.layer_post_mp = MLP(
                new_layer_config(
                    dim_in,
                    dim_out,
                    cfg.gnn.layers_post_mp,
                    has_act=False,
                    has_bias=True,
                    cfg=cfg,
                )
            )

    def _apply_index(self, batch):
        """Return ``(batch.x, batch.y)``, i.e. predictions and labels."""
        return batch.x, batch.y

    def forward(
        self,
        batch,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Apply the head to ``batch`` and return ``(pred, label)``.

        Only ``pred_th`` (``cfg.slt.pruning == "blockwise"``) or ``global_th``
        (``cfg.slt.pruning == "global"``) is used, and only when SLT
        single-/multi-mask encoder mode is active. The remaining keyword
        arguments are accepted for a uniform call signature and are unused
        here.

        Raises:
            ValueError: If SLT single-/multi-mask encoder mode is active and
                ``cfg.slt.pruning`` is neither ``"blockwise"`` nor
                ``"global"``.
        """
        if (
            cfg.slt.sm is True or cfg.slt.mm is True
        ) and cfg.slt.encoder is True:
            if cfg.slt.pruning == "blockwise":
                threshold = pred_th
            elif cfg.slt.pruning == "global":
                threshold = global_th
            else:
                # Previously the head was silently skipped, so raw dim_in
                # features were returned as predictions.
                raise ValueError(
                    f"Unsupported cfg.slt.pruning={cfg.slt.pruning!r}; "
                    "expected 'blockwise' or 'global'."
                )
            batch = self.layer_post_mp(batch, threshold)
        else:
            batch = self.layer_post_mp(batch)
        pred, label = self._apply_index(batch)
        return pred, label
