import torch
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import (register_node_encoder,
                                               register_edge_encoder)
from ..utils import embed_1D_scalar

"""
Generic node and edge encoders for datasets whose node/edge features consist
of a single type dictionary and therefore need only one nn.Embedding layer
(when cfg.framework.type == 'vfm', an nn.Linear projection is used instead).

The number of possible Node and Edge types must be set by cfg options:
1) cfg.dataset.nnode_types
2) cfg.dataset.nedge_types
"""


@register_node_encoder('OCBNode')
class OCBNodeEncoder(torch.nn.Module):
    """Node encoder for OCB datasets.

    The node type (column 0 of ``batch.x``, or ``batch.xt_logits`` when
    ``cfg.framework.type == 'vfm'``) is mapped to an ``emb_dim``-dimensional
    embedding. Optionally a scalar node feature (column 0 of
    ``batch.x_features``) is also encoded:

    * ``cfg.dataset.node_features_dim > 1`` (LEGACY): the encoded feature yields
      ``(gamma, beta)`` and the node embedding becomes
      ``type_emb * (1 + gamma) + beta``.
    * otherwise, if ``cfg.gt.sizing`` is set: ``batch.x`` holds the type
      embedding and the encoded feature is written to ``batch.x_features``.

    The legacy option takes precedence when both are enabled, both here and in
    ``forward``.

    Args:
        emb_dim: Output embedding dimension.

    Raises:
        ValueError: If ``cfg.dataset.nnode_types`` is smaller than 1.
    """

    def __init__(self, emb_dim):
        super().__init__()

        self.emb_dim = emb_dim
        num_types = cfg.dataset.nnode_types
        if num_types < 1:
            raise ValueError(f"Invalid 'nnode_types': {num_types}")

        if cfg.framework.type == 'vfm':
            # VFM provides a dense per-node tensor (batch.xt_logits) whose last
            # dimension must equal num_types, so project it linearly.
            self.type_encoder = torch.nn.Linear(num_types, emb_dim)
        else:
            # Integer node-type ids are looked up in an embedding table.
            self.type_encoder = torch.nn.Embedding(num_embeddings=num_types, embedding_dim=emb_dim)

        # Use if/elif so the encoder built here always matches the branch that
        # forward() takes. With two independent ifs and both options enabled,
        # the sizing encoder (output emb_dim) would overwrite the legacy one
        # (output 2 * emb_dim), and forward() would then fail when chunking the
        # output into (gamma, beta) and combining it with type_emb.
        if cfg.dataset.node_features_dim > 1:
            # LEGACY node feature prediction (device sizing); replaced by
            # cfg.gt.sizing, which carries features in batch.x_features.
            # Output is 2 * emb_dim so forward() can split it into (gamma, beta).
            # NOTE(review): forward() embeds only column 0 of x_features, so
            # this input width matches only if embed_1D_scalar returns emb_dim
            # features and node_features_dim == 2 -- confirm.
            self.feature_encoder = torch.nn.Sequential(
                torch.nn.Linear((cfg.dataset.node_features_dim - 1) * emb_dim, 2 * emb_dim),
                torch.nn.SiLU()
            )
        elif cfg.gt.get("sizing", False):
            self.feature_encoder = torch.nn.Sequential(
                torch.nn.Linear(emb_dim, emb_dim),
                torch.nn.SiLU()
            )

    def forward(self, batch):
        """Encode node types (and optional node features) on ``batch``.

        Args:
            batch: Graph batch. Reads ``xt_logits`` (VFM) or ``x`` (otherwise),
                plus ``x_features`` when feature encoding is enabled.

        Returns:
            The same ``batch`` object with ``x`` replaced by the node
            embedding and, in sizing mode, ``x_features`` replaced by the
            encoded feature.
        """
        if cfg.framework.type == 'vfm':
            type_emb = self.type_encoder(batch.xt_logits)
        else:
            # Only the first column of batch.x is the node type id.
            type_emb = self.type_encoder(batch.x[:, 0].long())

        if cfg.dataset.node_features_dim > 1:
            # The constant 200 is forwarded to embed_1D_scalar; its meaning is
            # defined by that helper.
            feature_emb = embed_1D_scalar(batch.x_features[:, 0], self.emb_dim, 200)
            gamma, beta = self.feature_encoder(feature_emb).chunk(2, dim=-1)
            # Scale-and-shift modulation of the type embedding by the feature.
            batch.x = type_emb * (1 + gamma) + beta
        elif cfg.gt.get("sizing", False):
            batch.x = type_emb
            feature_emb = embed_1D_scalar(batch.x_features[:, 0], self.emb_dim, 200)
            batch.x_features = self.feature_encoder(feature_emb)
        else:
            batch.x = type_emb
        return batch


@register_node_encoder('AnalogGenieNode')
class AnalogGenieNodeEncoder(torch.nn.Module):
    """Node-type encoder for AnalogGenie datasets.

    Maps the node type (column 0 of ``batch.x``, or ``batch.xt_logits`` when
    ``cfg.framework.type == 'vfm'``) to an ``emb_dim``-dimensional embedding
    and stores it in ``batch.x``.

    Args:
        emb_dim: Output embedding dimension.

    Raises:
        ValueError: If ``cfg.dataset.nnode_types`` is smaller than 1.
    """

    def __init__(self, emb_dim):
        super().__init__()

        self.emb_dim = emb_dim
        num_types = cfg.dataset.nnode_types
        if num_types < 1:
            raise ValueError(f"Invalid 'nnode_types': {num_types}")

        if cfg.framework.type == 'vfm':
            # VFM provides a dense per-node tensor (batch.xt_logits) whose last
            # dimension must equal num_types, so project it linearly.
            self.type_encoder = torch.nn.Linear(num_types, emb_dim)
        else:
            # Integer node-type ids are looked up in an embedding table.
            self.type_encoder = torch.nn.Embedding(num_embeddings=num_types, embedding_dim=emb_dim)

    def forward(self, batch):
        """Replace ``batch.x`` by the node-type embedding and return ``batch``."""
        if cfg.framework.type == 'vfm':
            batch.x = self.type_encoder(batch.xt_logits)
        else:
            # Encode just the first column as the node type. nn.Embedding
            # needs integer indices; .long() is a no-op for int64 tensors, so
            # no dtype check is needed (the former int64 / non-int64 branches
            # were identical).
            batch.x = self.type_encoder(batch.x[:, 0].long())
        return batch


# @register_node_encoder('TypeDictNode')
# class TypeDictNodeEncoder(torch.nn.Module):
#     def __init__(self, emb_dim):
#         super().__init__()

#         assert emb_dim % 2 == 0, "Embedding dimension must be even"
#         num_types = cfg.dataset.nnode_types
#         num_features = cfg.dataset.nnode_features
#         if num_types < 1:
#             raise ValueError(f"Invalid 'nnode_types': {num_types}")

#         if cfg.framework.type == 'vfm':
#             self.type_encoder = torch.nn.Linear(num_types, emb_dim)
#         # if we have just node type all the embedding dimension will be used for the type_encoder, otherwise half of it will be used for the type_encoder and the other half for the feature_encoder
#         elif cfg.dataset.node_features_dim == 1:
#             self.type_encoder = torch.nn.Embedding(num_embeddings=num_types,
#                                           embedding_dim=emb_dim)
#         else:   
#             self.type_encoder = torch.nn.Embedding(num_embeddings=num_types,
#                                             embedding_dim=emb_dim//2) # TODO or expansion -> reduction
            
#         self.feature_encoder = torch.nn.Linear(1, emb_dim//2) # TODO we want to encode this more properly -> rbf here too

#         # suppose a embedding dimension of 64, then 32 will be used for the type_encoder
#         # and 32 will be used for the feature_encoder

        
#         # torch.nn.init.xavier_uniform_(self.encoder.weight.data)

#     def forward(self, batch):
#         if cfg.framework.type == 'vfm':
#             type_emb = self.type_encoder(batch.xt_logits)
#         else:
#             # Encode just the first dimension as node type
#             type_emb = self.type_encoder(batch.x[:, 0])
#         # Encode the rest as node features
#         # TODO: Check if mean is the right aggregation TODO: THIS SHOULD BE UPDATED ASAP
#         if cfg.dataset.node_features_dim > 1:
#             feature_emb = self.feature_encoder(batch.x[:, 1:cfg.dataset.node_features_dim].float()) 
#             batch.x = torch.cat([type_emb, feature_emb], dim=1)
#         else:
#             batch.x = type_emb
#         return batch


@register_edge_encoder('TypeDictEdge')
class TypeDictEdgeEncoder(torch.nn.Module):
    """Encode ``batch.edge_attr`` into ``emb_dim``-dimensional edge embeddings.

    Args:
        emb_dim: Output embedding dimension.

    Raises:
        ValueError: If ``cfg.dataset.nedge_types`` is smaller than 1.
    """

    def __init__(self, emb_dim):
        super().__init__()

        n_edge_types = cfg.dataset.nedge_types
        if n_edge_types < 1:
            raise ValueError(f"Invalid 'nedge_types': {n_edge_types}")

        is_vfm = cfg.framework.type == 'vfm'
        # VFM: edge_attr is projected by a Linear layer with a fixed input
        # width of 2. NOTE(review): this width does not depend on nedge_types;
        # confirm it matches the VFM edge_attr layout.
        # Otherwise: edge_attr must hold integer type ids, as nn.Embedding
        # requires.
        self.encoder = (
            torch.nn.Linear(2, emb_dim)
            if is_vfm
            else torch.nn.Embedding(num_embeddings=n_edge_types, embedding_dim=emb_dim)
        )

    def forward(self, batch):
        """Replace ``batch.edge_attr`` by its encoding and return ``batch``."""
        encoded_edges = self.encoder(batch.edge_attr)
        batch.edge_attr = encoded_edges
        return batch
