import torch
from torch_geometric.graphgym import cfg
from torch_geometric.graphgym.register import register_edge_encoder

from graphgps.slt.monarch_linear import MonarchLinear
from graphgps.slt.sparse_modules import (
    NMSparseMultiLinear,
    SparseLinear,
    SparseLinearMulti_mask,
)


@register_edge_encoder("LinearEdge")
class LinearEdgeEncoder(torch.nn.Module):
    """Project ``batch.edge_attr`` to ``emb_dim`` with a (possibly sparse) linear map.

    The projection module is chosen from the global GraphGym ``cfg``:

    * ``cfg.slt.encoder`` is not True -> dense ``torch.nn.Linear`` (with bias).
    * ``cfg.slt.encoder`` is True -> the first enabled variant, checked in
      this order: ``cfg.monarch.encoder`` (``MonarchLinear``),
      ``cfg.slt.srste`` (``NMSparseMultiLinear``), ``cfg.slt.sm``
      (``SparseLinear``), ``cfg.slt.mm`` (``SparseLinearMulti_mask``).

    Args:
        emb_dim: Output embedding dimension.

    Raises:
        ValueError: If ``cfg.dataset.name`` has no hard-coded input edge
            feature dimension, or if ``cfg.slt.encoder`` is True but none of
            the encoder variant flags is enabled.
    """

    def __init__(self, emb_dim):
        super().__init__()
        if cfg.dataset.name in ["MNIST", "CIFAR10"]:
            # Input edge feature dimension is hard-coded per dataset.
            self.in_dim = 1
        else:
            raise ValueError(
                "Input edge feature dim is required to be hardset "
                "or refactored to use a cfg option."
            )

        if cfg.slt.encoder is True:
            if cfg.monarch.encoder is True:
                self.encoder = MonarchLinear(self.in_dim, emb_dim, bias=False)
            elif cfg.slt.srste is True:
                self.encoder = NMSparseMultiLinear(self.in_dim, emb_dim)
            elif cfg.slt.sm is True:
                self.encoder = SparseLinear(self.in_dim, emb_dim, bias=False)
            elif cfg.slt.mm is True:
                self.encoder = SparseLinearMulti_mask(
                    self.in_dim, emb_dim, bias=False
                )
            else:
                # Previously this fell through with ``self.encoder`` unset,
                # surfacing later as an AttributeError inside ``forward``.
                raise ValueError(
                    "cfg.slt.encoder is True but none of cfg.monarch.encoder, "
                    "cfg.slt.srste, cfg.slt.sm or cfg.slt.mm is enabled; "
                    "cannot build LinearEdge encoder."
                )
        else:
            self.encoder = torch.nn.Linear(self.in_dim, emb_dim)

    def forward(
        self,
        batch,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Encode ``batch.edge_attr`` in place and return ``batch``.

        ``batch.edge_attr`` is reshaped to ``(-1, self.in_dim)`` before being
        passed to ``self.encoder``. When ``cfg.slt.encoder`` is True and
        ``cfg.slt.sm`` or ``cfg.slt.mm`` is set, a pruning threshold is also
        passed to the encoder:

        * ``cfg.slt.pruning == "blockwise"`` -> ``encoder_th``
        * ``cfg.slt.pruning == "global"``    -> ``global_th``

        ``cur_epoch``, ``mpnn_th``, ``msa_th``, ``ffn_th`` and ``pred_th`` are
        accepted but unused here (presumably so callers can use a signature
        shared with other modules — TODO confirm).

        Raises:
            ValueError: If a threshold is required but ``cfg.slt.pruning`` is
                neither ``"blockwise"`` nor ``"global"``.
        """
        edge_attr = batch.edge_attr.view(-1, self.in_dim)

        # NOTE(review): ``__init__`` gives monarch/srste priority over sm/mm,
        # but this check only looks at sm/mm. If e.g. monarch and sm are both
        # enabled, a threshold is passed to MonarchLinear — confirm that is
        # intended / supported.
        if (
            cfg.slt.sm is True or cfg.slt.mm is True
        ) and cfg.slt.encoder is True:
            if cfg.slt.pruning == "blockwise":
                threshold = encoder_th
            elif cfg.slt.pruning == "global":
                threshold = global_th
            else:
                # Previously an unknown mode silently left edge_attr
                # un-encoded (wrong width for downstream layers).
                raise ValueError(
                    f"Unsupported cfg.slt.pruning={cfg.slt.pruning!r}; "
                    "expected 'blockwise' or 'global'."
                )
            batch.edge_attr = self.encoder(edge_attr, threshold)
        else:
            batch.edge_attr = self.encoder(edge_attr)
        return batch
