import torch
from torch_geometric.graphgym import cfg
from torch_geometric.graphgym.register import register_node_encoder

from grit.slt.sparse_modules import SparseLinear, SparseLinearMulti_mask


@register_node_encoder("LinearNode")
class LinearNodeEncoder(torch.nn.Module):
    """Linearly project node features ``batch.x`` from ``cfg.share.dim_in`` to ``emb_dim``.

    The projection layer is chosen from the ``cfg.slt`` options at construction time:

    * ``cfg.slt.sm`` and ``cfg.slt.encoder`` -> ``SparseLinear``
    * else ``cfg.slt.mm`` and ``cfg.slt.encoder`` -> ``SparseLinearMulti_mask``
    * otherwise -> dense ``torch.nn.Linear``

    Args:
        emb_dim: output feature dimension.
    """

    def __init__(self, emb_dim):
        super().__init__()

        # ``sm`` is checked first, so it wins if both ``sm`` and ``mm`` are set.
        if cfg.slt.sm is True and cfg.slt.encoder is True:
            self.encoder = SparseLinear(cfg.share.dim_in, emb_dim)
        elif cfg.slt.mm is True and cfg.slt.encoder is True:
            self.encoder = SparseLinearMulti_mask(cfg.share.dim_in, emb_dim)
        else:
            self.encoder = torch.nn.Linear(cfg.share.dim_in, emb_dim)

    def forward(
        self,
        batch,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Replace ``batch.x`` with its encoded version and return ``batch``.

        Args:
            batch: object with a node-feature attribute ``x``; it is modified in place.
            cur_epoch, mpnn_th, msa_th, ffn_th, pred_th: not used by this encoder.
                NOTE(review): presumably accepted so every module can be called with
                the same keyword set — confirm against the caller.
            encoder_th: forwarded as the sparse layer's second argument when
                ``cfg.slt.pruning == "blockwise"``.
            global_th: forwarded as the sparse layer's second argument when
                ``cfg.slt.pruning == "global"``.

        Returns:
            The same ``batch`` object, with ``x`` replaced.

        Raises:
            ValueError: if a sparse encoder is active and ``cfg.slt.pruning`` is
                neither ``"blockwise"`` nor ``"global"``. Previously this case left
                ``batch.x`` silently unencoded.
        """
        uses_sparse = (
            cfg.slt.sm is True or cfg.slt.mm is True
        ) and cfg.slt.encoder is True

        if not uses_sparse:
            # Dense torch.nn.Linear path: no extra argument.
            batch.x = self.encoder(batch.x)
            return batch

        if cfg.slt.pruning == "blockwise":
            sparse_arg = encoder_th
        elif cfg.slt.pruning == "global":
            sparse_arg = global_th
        else:
            raise ValueError(
                f"LinearNodeEncoder: unsupported cfg.slt.pruning={cfg.slt.pruning!r}; "
                "expected 'blockwise' or 'global'"
            )

        batch.x = self.encoder(batch.x, sparse_arg)
        return batch
