import torch
from torch_geometric.graphgym import cfg
from torch_geometric.graphgym.register import register_edge_encoder

from graphgps.slt.sparse_modules import SparseEmbedding


@register_edge_encoder("DummyEdge")
class DummyEdgeEncoder(torch.nn.Module):
    """Edge encoder that assigns the same learned embedding to every edge.

    Every edge is given the integer feature ``0``, which is then passed
    through ``self.encoder`` and stored in ``batch.edge_attr``.

    The encoder is a ``SparseEmbedding`` when ``cfg.slt.embedding`` and
    ``cfg.slt.encoder`` are ``True`` and at least one of ``cfg.slt.sm`` /
    ``cfg.slt.mm`` is ``True``; otherwise it is a plain one-row
    ``torch.nn.Embedding``.

    Args:
        emb_dim: Size of the embedding produced for each edge.
    """

    def __init__(self, emb_dim):
        super().__init__()

        if self._use_sparse_encoder():
            self.encoder = SparseEmbedding(emb_dim=emb_dim, feature_dims=1)
        else:
            self.encoder = torch.nn.Embedding(
                num_embeddings=1, embedding_dim=emb_dim
            )
        # NOTE(review): an explicit Xavier-uniform init of the encoder weight
        # used to be applied here but was commented out; the encoder's own
        # default initialization is what is used now.

    @staticmethod
    def _use_sparse_encoder():
        """Return True when the config selects the sparse embedding path.

        Kept in one place so that construction and ``forward`` always agree
        on which encoder type (and therefore which call signature) is used.
        """
        return (
            cfg.slt.embedding is True
            and (cfg.slt.sm is True or cfg.slt.mm is True)
            and cfg.slt.encoder is True
        )

    def forward(
        self,
        batch,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Write the dummy edge embedding into ``batch.edge_attr``.

        Args:
            batch: Object exposing ``edge_index``; its second dimension is
                used as the number of edges. ``batch.edge_attr`` is
                overwritten in place.
            cur_epoch, mpnn_th, msa_th, ffn_th, pred_th: Accepted but not
                used by this encoder (presumably kept so all encoders share
                one call signature -- confirm against the caller).
            encoder_th: Passed as the second argument to the sparse encoder
                when ``cfg.slt.pruning == "blockwise"``.
            global_th: Passed as the second argument to the sparse encoder
                when ``cfg.slt.pruning == "global"``.

        Returns:
            The same ``batch`` object, with ``edge_attr`` set.

        Raises:
            ValueError: If the sparse encoder is active and
                ``cfg.slt.pruning`` is neither ``"global"`` nor
                ``"blockwise"``.
        """
        # One zero per edge, created with the dtype/device of edge_index.
        dummy_attr = batch.edge_index.new_zeros(batch.edge_index.shape[1])

        if not self._use_sparse_encoder():
            batch.edge_attr = self.encoder(dummy_attr)
            return batch

        # The sparse encoder takes a second argument whose source depends on
        # the configured pruning mode.
        pruning = cfg.slt.pruning
        if pruning == "global":
            threshold = global_th
        elif pruning == "blockwise":
            threshold = encoder_th
        else:
            # Previously this case fell through and returned the batch with
            # edge_attr left untouched, hiding the misconfiguration.
            raise ValueError(
                f"Unsupported cfg.slt.pruning {pruning!r} for the sparse "
                "DummyEdge encoder; expected 'global' or 'blockwise'."
            )

        batch.edge_attr = self.encoder(dummy_attr, threshold)
        return batch
