import torch
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_network
from ..encoder.node_features_encoder import FeatureEncoder
from ..layer.output_layer import OutputLayer

@register_network('GritTransformer')
class GritTransformer(torch.nn.Module):
    """GritTransformer (Graph Inductive Bias Transformer) network.

    NOTE(review): the original docstring left its reference as a placeholder.
    The architecture appears to follow Ma et al., "Graph Inductive Biases in
    Transformers without Message Passing" (ICML 2023) -- confirm and add the
    canonical link.

    The model is a pipeline of submodules, each taking and returning the
    ``batch`` object:

    1. ``encoder``: a ``FeatureEncoder`` built from ``dim_in``.
    2. ``graph_features_encoder`` (optional): created when
       ``cfg.gt.conditional_gen`` or ``cfg.gt.time_conditioning`` is set.
    3. ``rrwp_abs_encoder`` / ``rrwp_rel_encoder`` (optional): created when
       ``cfg.posenc_RRWP.enable`` is set.
    4. ``layers``: ``cfg.gt.layers`` transformer layers in a ``Sequential``.
    5. ``output_layer``: an ``OutputLayer`` producing node and edge outputs.

    ``forward`` iterates ``self.children()``, so the order in which the
    submodules are assigned in ``__init__`` *is* the execution order.
    """

    def __init__(self, dim_in, dim_out):
        """Build the submodules from the global GraphGym ``cfg``.

        Args:
            dim_in: Input feature dimension handed to ``FeatureEncoder``.
            dim_out: Unused here; the output sizes are derived from
                ``cfg.dataset.nnode_types`` / ``cfg.dataset.nedge_types``.
                Kept for the GraphGym network constructor signature.

        Raises:
            AssertionError: If ``cfg.gt.dim_hidden``, ``cfg.gnn.dim_inner``
                and the encoder's output dimension are not all equal.
            ValueError: If ``cfg.gt.layer_type`` is not a registered layer
                while ``cfg.gt.layers`` requests at least one layer.
        """
        super().__init__()

        # IMPORTANT: do not reorder the submodule assignments below.
        # forward() walks self.children(), which yields modules in
        # registration order.
        self.encoder = FeatureEncoder(dim_in)

        # Graph features encoder: encodes both conditioning features and time
        if cfg.gt.conditional_gen:
            self.graph_features_encoder = register.layer_dict[
                "graph_features_encoder"](cfg.gnn.dim_inner)
        elif cfg.gt.time_conditioning:
            self.graph_features_encoder = register.layer_dict[
                "time_encoder"](cfg.gnn.dim_inner)

        # From here on, dim_in is the dimension reported by the encoder.
        dim_in = self.encoder.dim_in

        # With the 'masked' prior one class is dropped from the outputs
        # (original note: "mb avoid masking token").
        inc = -1 if cfg.train.prior == 'masked' else 0
        node_dim_out = cfg.dataset.nnode_types + inc
        edge_dim_out = cfg.dataset.nedge_types + inc

        if cfg.posenc_RRWP.enable:
            self.rrwp_abs_encoder = register.node_encoder_dict["rrwp_linear"](
                cfg.posenc_RRWP.ksteps, cfg.gnn.dim_inner)
            rel_pe_dim = cfg.posenc_RRWP.ksteps
            self.rrwp_rel_encoder = register.edge_encoder_dict["rrwp_linear"](
                rel_pe_dim,
                cfg.gnn.dim_edge,
                pad_to_full_graph=cfg.gt.attn.full_attn,
                add_node_attr_as_self_loop=False,
                fill_value=0.,
            )

        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        global_model_type = cfg.gt.get('layer_type', "GritTransformer")
        TransformerLayer = register.layer_dict.get(global_model_type)
        # Fail with an explicit message instead of the opaque
        # "'NoneType' object is not callable" raised when an unregistered
        # layer type is instantiated below.
        if TransformerLayer is None and cfg.gt.layers > 0:
            raise ValueError(
                f"Unknown transformer layer type {global_model_type!r}: "
                "it is not registered in register.layer_dict."
            )

        layers = [
            TransformerLayer(
                in_dim=cfg.gt.dim_hidden,
                out_dim=cfg.gt.dim_hidden,
                num_heads=cfg.gt.n_heads,
                dropout=cfg.gt.dropout,
                act=cfg.gnn.act,
                attn_dropout=cfg.gt.attn_dropout,
                layer_norm=cfg.gt.layer_norm,
                batch_norm=cfg.gt.batch_norm,
                residual=True,
                norm_e=cfg.gt.attn.norm_e,
                O_e=cfg.gt.attn.O_e,
                cfg=cfg.gt,
                features=cfg.gt.get("sizing", False),
                coupling=cfg.gt.attn.get("x_f_coupling", False),
                conditioning=cfg.gt.time_conditioning,
            )
            for _ in range(cfg.gt.layers)
        ]

        self.layers = torch.nn.Sequential(*layers)
        self.output_layer = OutputLayer(
            cfg.gt.dim_hidden, node_dim_out, cfg.gt.dim_hidden, edge_dim_out)

    def forward(self, batch, unconditional_prop=0):
        """Run ``batch`` through every child module in registration order.

        Args:
            batch: Batch object carrying at least ``edge_attr``; it is
                modified in place and passed from module to module.
            unconditional_prop: Forwarded as a keyword argument only to a
                child whose class is named ``GraphFeaturesEncoder``; its
                meaning is defined by that encoder (presumably the share of
                samples treated as unconditional -- confirm).

        Returns:
            Whatever the last child module (``output_layer``) returns.
        """
        if cfg.framework.type != 'vfm':
            # Only edge attributes are cast to int; batch.x is left as-is.
            batch.edge_attr = batch.edge_attr.int()

        for module in self.children():
            # Matching on the class name avoids importing the encoder class
            # here; only that encoder accepts `unconditional_prop`.
            if module.__class__.__name__ == 'GraphFeaturesEncoder':
                kwargs = {'unconditional_prop': unconditional_prop}
            else:
                kwargs = {}
            batch = module(batch, **kwargs)

        return batch
