import torch
import numpy as np
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import (register_node_encoder,
                                               register_edge_encoder)
from ..utils import embed_1D_scalar

"""
Generic node and edge encoders for datasets whose node/edge features consist
of a single type dictionary and thus need only one nn.Embedding layer
(in 'vfm' mode a Linear layer over type logits is used instead).

The number of possible Node and Edge types must be set by cfg options:
1) cfg.dataset.nnode_types
2) cfg.dataset.nedge_types
"""


@register_node_encoder('OCBNode')
class OCBNodeEncoder(torch.nn.Module):
    """Node encoder for OCB graphs: node-type embedding plus optional extras.

    Sub-modules created, depending on the config:

    * ``type_encoder``: ``nn.Linear(nnode_types, emb_dim)`` when
      ``framework.type == 'vfm'`` (applied to ``batch.xt_logits``), otherwise
      ``nn.Embedding(nnode_types, emb_dim)`` (applied to ``batch.x[:, 0]``).
    * ``feature_encoder``: ``Linear + SiLU`` applied to a scalar embedding of
      ``batch.x_features[:, 0]``; only when ``gt.sizing`` is set.
    * ``supernode_encoder`` / ``null_token``: only when both
      ``gt.conditional_gen`` and ``gt.supernode`` are set.

    Args:
        emb_dim: Output embedding dimension.
        model_cfg: Config node to read every option from. Defaults to the
            global GraphGym ``cfg``.

    Raises:
        ValueError: If ``dataset.nnode_types`` is less than 1.
    """

    def __init__(self, emb_dim, model_cfg=None):
        super().__init__()

        self.model_cfg = cfg if model_cfg is None else model_cfg
        # All options are read through self.model_cfg. The raw ``model_cfg``
        # argument may be None, and the global cfg may differ from an
        # explicitly supplied config.
        mcfg = self.model_cfg

        self.emb_dim = emb_dim
        num_types = mcfg.dataset.nnode_types
        if num_types < 1:
            raise ValueError(f"Invalid 'nnode_types': {num_types}")

        if mcfg.framework.type == 'vfm':
            # In vfm mode forward() feeds batch.xt_logits, so use a dense projection.
            self.type_encoder = torch.nn.Linear(num_types, emb_dim)
        else:
            self.type_encoder = torch.nn.Embedding(num_embeddings=num_types, embedding_dim=emb_dim)

        # Sizes
        if mcfg.gt.get("sizing", False):
            self.feature_encoder = torch.nn.Sequential(
                torch.nn.Linear(emb_dim, emb_dim),
                torch.nn.SiLU()
            )

        # Supernodes
        if mcfg.gt.get('conditional_gen', False) and mcfg.gt.get('supernode', False):
            self.supernode_encoder = torch.nn.Embedding(
                num_embeddings=mcfg.gnn.n_bins ** mcfg.gnn.n_spec,
                embedding_dim=emb_dim,
            )
            # Trainable embedding that forward() substitutes for randomly
            # masked supernodes (unconditional generation).
            self.null_token = torch.nn.Parameter(torch.zeros(1, emb_dim), requires_grad=True)
            torch.nn.init.normal_(self.null_token, mean=0.0, std=0.02)

    def forward(self, batch, unconditional_prop=0.0):
        """Encode node types (and optionally features / supernodes) into ``batch``.

        Args:
            batch: Graph batch. Reads ``xt_logits`` (vfm mode) or ``x``, plus
                ``x_features`` / ``supernode_x_index`` when the corresponding
                options are enabled.
            unconditional_prop: Fraction of supernodes whose embedding is
                replaced by ``null_token``. Used only when supernode encoding
                is enabled and ``gt.process_feats_with_x`` is off.

        Returns:
            The same ``batch``, with ``batch.x`` (and, in the sizing branch
            without ``process_feats_with_x``, ``batch.x_features``)
            overwritten by embeddings.
        """
        mcfg = self.model_cfg
        if mcfg.framework.type == 'vfm':
            type_emb = self.type_encoder(batch.xt_logits)
        else:
            # Encode just the first column as node type. Ids above the last
            # valid type are clamped rather than raising an index error.
            type_emb = self.type_encoder(
                batch.x[:, 0].clamp(max=mcfg.dataset.nnode_types - 1).long())

        sizing = mcfg.gt.get("sizing", False)
        if sizing and mcfg.gt.get("process_feats_with_x", False):
            # Merge node type and feature into batch.x. 50 + 50 * f maps
            # features in [-1, 1] to [0, 100] before the scalar embedding.
            feature_emb = embed_1D_scalar(50 + 50 * batch.x_features[:, 0], self.emb_dim, 200)
            feature_emb = self.feature_encoder(feature_emb)
            # Suppress uninformative node features (In, Out, net): only node
            # types >= 8 keep their feature embedding.
            keep_mask = (batch.x[:, 0] >= 8)
            feature_emb = feature_emb * keep_mask[:, None]
            batch.x = type_emb + feature_emb
        else:
            if hasattr(self, 'supernode_encoder'):
                supernode_emb = self.supernode_encoder(batch.x[batch.supernode_x_index, 0].long())
                # Random masking for unconditional generation
                uncond_idx = np.random.choice(np.arange(len(supernode_emb)), int(unconditional_prop * len(supernode_emb)), replace=False)
                supernode_emb[uncond_idx] = self.null_token
                # Put the supernode embeddings back at their node positions.
                type_emb[batch.supernode_x_index] = supernode_emb
            # Otherwise encode node features separately into batch.x_features.
            if sizing:
                feature_emb = embed_1D_scalar(batch.x_features[:, 0], self.emb_dim, 200)
                batch.x_features = self.feature_encoder(feature_emb)
            batch.x = type_emb

        return batch


@register_node_encoder('AnalogGenieNode')
class AnalogGenieNodeEncoder(torch.nn.Module):
    """Node-type encoder for AnalogGenie graphs.

    Uses ``nn.Linear(nnode_types, emb_dim)`` over ``batch.xt_logits`` when
    ``framework.type == 'vfm'``. Otherwise uses
    ``nn.Embedding(nnode_types, emb_dim)`` over ``batch.x[:, 0]``.

    Args:
        emb_dim: Output embedding dimension.
        **kwargs: Accepted for interface compatibility with the other node
            encoders. An optional ``model_cfg`` entry is used as the config
            source (defaulting to the global GraphGym ``cfg``). Other keys
            are ignored.

    Raises:
        ValueError: If ``dataset.nnode_types`` is less than 1.
    """

    def __init__(self, emb_dim, **kwargs):
        super().__init__()

        # Honour an explicitly passed config, like the sibling encoders do.
        model_cfg = kwargs.get('model_cfg')
        self.model_cfg = cfg if model_cfg is None else model_cfg

        self.emb_dim = emb_dim
        num_types = self.model_cfg.dataset.nnode_types
        if num_types < 1:
            raise ValueError(f"Invalid 'nnode_types': {num_types}")

        if self.model_cfg.framework.type == 'vfm':
            self.type_encoder = torch.nn.Linear(num_types, emb_dim)
        else:
            self.type_encoder = torch.nn.Embedding(num_embeddings=num_types, embedding_dim=emb_dim)

    def forward(self, batch):
        """Replace ``batch.x`` with node-type embeddings and return ``batch``."""
        if self.model_cfg.framework.type == 'vfm':
            batch.x = self.type_encoder(batch.xt_logits)
        else:
            # Encode just the first column as node type. ``.long()`` is a
            # no-op when the column is already int64, so no dtype branch is needed.
            batch.x = self.type_encoder(batch.x[:, 0].long())
        return batch


# @register_node_encoder('TypeDictNode')
# class TypeDictNodeEncoder(torch.nn.Module):
#     def __init__(self, emb_dim):
#         super().__init__()

#         assert emb_dim % 2 == 0, "Embedding dimension must be even"
#         num_types = cfg.dataset.nnode_types
#         num_features = cfg.dataset.nnode_features
#         if num_types < 1:
#             raise ValueError(f"Invalid 'nnode_types': {num_types}")

#         if cfg.framework.type == 'vfm':
#             self.type_encoder = torch.nn.Linear(num_types, emb_dim)
#         # if we have just node type all the embedding dimension will be used for the type_encoder, otherwise half of it will be used for the type_encoder and the other half for the feature_encoder
#         elif cfg.dataset.node_features_dim == 1:
#             self.type_encoder = torch.nn.Embedding(num_embeddings=num_types,
#                                           embedding_dim=emb_dim)
#         else:   
#             self.type_encoder = torch.nn.Embedding(num_embeddings=num_types,
#                                             embedding_dim=emb_dim//2) # TODO or expansion -> reduction
            
#         self.feature_encoder = torch.nn.Linear(1, emb_dim//2) # TODO we want to encode this more properly -> rbf here too

#         # suppose a embedding dimension of 64, then 32 will be used for the type_encoder
#         # and 32 will be used for the feature_encoder

        
#         # torch.nn.init.xavier_uniform_(self.encoder.weight.data)

#     def forward(self, batch):
#         if cfg.framework.type == 'vfm':
#             type_emb = self.type_encoder(batch.xt_logits)
#         else:
#             # Encode just the first dimension as node type
#             type_emb = self.type_encoder(batch.x[:, 0])
#         # Encode the rest as node features
#         # TODO: Check if mean is the right aggregation TODO: THIS SHOULD BE UPDATED ASAP
#         if cfg.dataset.node_features_dim > 1:
#             feature_emb = self.feature_encoder(batch.x[:, 1:cfg.dataset.node_features_dim].float()) 
#             batch.x = torch.cat([type_emb, feature_emb], dim=1)
#         else:
#             batch.x = type_emb
#         return batch


@register_edge_encoder('TypeDictEdge')
class TypeDictEdgeEncoder(torch.nn.Module):
    """Edge encoder that maps ``batch.edge_attr`` to ``emb_dim``-sized vectors.

    In ``vfm`` mode the encoder is a dense ``nn.Linear(2, emb_dim)``.
    Otherwise it is an ``nn.Embedding`` with ``dataset.nedge_types`` rows.

    Args:
        emb_dim: Output embedding dimension.
        model_cfg: Config node to read options from. Defaults to the global
            GraphGym ``cfg``.

    Raises:
        ValueError: If ``dataset.nedge_types`` is less than 1.
    """

    def __init__(self, emb_dim, model_cfg=None):
        super().__init__()

        self.model_cfg = model_cfg if model_cfg is not None else cfg
        n_edge_types = self.model_cfg.dataset.nedge_types
        if n_edge_types < 1:
            raise ValueError(f"Invalid 'nedge_types': {n_edge_types}")

        use_vfm = self.model_cfg.framework.type == 'vfm'
        # NOTE(review): the vfm input width is hard-coded to 2 rather than
        # derived from nedge_types; confirm it matches batch.edge_attr.
        self.encoder = (
            torch.nn.Linear(2, emb_dim)
            if use_vfm
            else torch.nn.Embedding(num_embeddings=n_edge_types, embedding_dim=emb_dim)
        )

    def forward(self, batch):
        """Overwrite ``batch.edge_attr`` with its encoding and return ``batch``."""
        encoded = self.encoder(batch.edge_attr)
        batch.edge_attr = encoded
        return batch
