import torch
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_network
from ..encoder.node_features_encoder import FeatureEncoder, OneHotPerturb
from ..layer.output_layer import OutputLayer

@register_network('GritTransformer')
class GritTransformer(torch.nn.Module):
    '''
        Graph transformer network, registered with GraphGym under the name 'GritTransformer'.

        Adapted from `Graph Inductive Biases in Transformers without Message Passing` (https://arxiv.org/pdf/2305.17589)

        NOTE: `forward` applies the direct submodules one after another by iterating
        `self.children()`, so the order in which submodules are assigned in
        `__init__` defines the processing pipeline.
    '''

    def __init__(self, dim_in, dim_out, model_cfg=None):
        '''
        dim_in: token dimension; passed to the feature encoder, after which the
            encoder's own `dim_in` attribute is used for the dimension check
        dim_out: not read by this constructor; output sizes are taken from the
            configuration instead (`dataset.nnode_types` / `dataset.nedge_types`
            for the generative output layer, `gnn.n_spec` / `gnn.n_bins` for the
            classification head)
        model_cfg: configuration node; the global GraphGym `cfg` is used when None

        Configuration values read for the classification head:
            gnn.n_spec: the number of conditioning quantities (None if absent)
            gnn.n_bins: the number of classes for each spec (None if absent)
        '''
        super().__init__()
        if model_cfg is None:
            model_cfg = cfg
        self.model_cfg = model_cfg
        dim_spec_out = model_cfg.gnn.get('n_spec', None)
        dim_bins_out = model_cfg.gnn.get('n_bins', None)

        # --- Submodules. Assignment order matters: forward() runs them in this order. ---
        self.encoder = FeatureEncoder(dim_in, model_cfg)
        # Graph features encoder: encodes both conditioning features and time.
        # It is created for every task type.
        self.graph_features_encoder = register.layer_dict["time_encoder"](model_cfg.gnn.dim_inner)

        dim_in = self.encoder.dim_in
        node_dim_out = model_cfg.dataset.nnode_types
        edge_dim_out = model_cfg.dataset.nedge_types

        if model_cfg.posenc_RRWP.enable:
            rel_pe_dim = model_cfg.posenc_RRWP.ksteps
            self.rrwp_abs_encoder = register.node_encoder_dict["rrwp_linear"](
                model_cfg.posenc_RRWP.ksteps, model_cfg.gnn.dim_inner)
            self.rrwp_rel_encoder = register.edge_encoder_dict["rrwp_linear"](
                rel_pe_dim, model_cfg.gnn.dim_edge,
                pad_to_full_graph=model_cfg.gt.attn.full_attn,
                add_node_attr_as_self_loop=False,
                fill_value=0.
            )

        assert model_cfg.gt.dim_hidden == model_cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        # The transformer layer class is looked up by name in the GraphGym layer registry.
        global_model_type = model_cfg.gt.get('layer_type', "GritTransformer")
        TransformerLayer = register.layer_dict.get(global_model_type)

        # Boolean inputs to the tf layer
        conditioning = model_cfg.gt.time_conditioning
        # Separate spec conditioning is only considered for the generative task.
        if model_cfg.dataset.task_type == 'generative':
            spec_conditioning = (model_cfg.gt.conditional_gen
                                 and model_cfg.gt.get('sep_t_spec_cond', False))
        else:
            spec_conditioning = False
        # `== False` is kept as written in the original so that the result for
        # non-bool config values (e.g. None) is unchanged.
        features = model_cfg.gt.get("sizing", False) \
            and (model_cfg.gt.get("process_feats_with_x", False) == False)
        coupling = model_cfg.gt.attn.get("x_f_coupling", False) and features

        # One independent layer instance per configured transformer layer.
        # `ignore_edges` is read from `model_cfg` (not the global `cfg`) so that an
        # explicitly supplied configuration is honoured for every layer argument.
        layers = [
            TransformerLayer(
                in_dim=model_cfg.gt.dim_hidden,
                out_dim=model_cfg.gt.dim_hidden,
                num_heads=model_cfg.gt.n_heads,
                dropout=model_cfg.gt.dropout,
                act=model_cfg.gnn.act,
                attn_dropout=model_cfg.gt.attn_dropout,
                layer_norm=model_cfg.gt.layer_norm,
                batch_norm=model_cfg.gt.batch_norm,
                residual=True,
                norm_e=model_cfg.gt.attn.norm_e,
                O_e=model_cfg.gt.attn.O_e,
                cfg=model_cfg.gt,
                features=features,
                coupling=coupling,
                conditioning=conditioning,
                spec_conditioning=spec_conditioning,
                ignore_edges=model_cfg.dataset.get('ignore_edges', False),
            )
            for _ in range(model_cfg.gt.layers)
        ]
        self.layers = torch.nn.Sequential(*layers)

        # Task-specific head. No head is created for other task types.
        if model_cfg.dataset.task_type == 'generative':
            self.output_layer = OutputLayer(model_cfg.gt.dim_hidden, node_dim_out,
                                            model_cfg.gt.dim_hidden, edge_dim_out)
        elif model_cfg.dataset.task_type == 'classification':
            GNNHead = register.head_dict[model_cfg.gnn.head]
            self.post_mp = GNNHead(dim_in=model_cfg.gt.dim_hidden,
                                   dim_spec_out=dim_spec_out,
                                   dim_bins_out=dim_bins_out,
                                   dual_head=model_cfg.gnn.get('dual_head', False))

    def forward(self, batch, unconditional_prop=0.0):
        '''
        Pass `batch` through every direct submodule in turn and return the result.

        batch: the object handed to the first submodule; each submodule's return
            value becomes the input of the next one
        unconditional_prop: forwarded as a keyword argument to the submodule
            whose class is named 'FeatureEncoder' only; all other submodules
            are called with `batch` alone
        '''
        for module in self.children():
            # Only the feature encoder accepts the extra keyword argument.
            if module.__class__.__name__ == 'FeatureEncoder':
                kwargs = {'unconditional_prop': unconditional_prop}
            else:
                kwargs = {}
            batch = module(batch, **kwargs)

        return batch
