import torch
import torch.nn as nn
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import GNNPreMP
from torch_geometric.graphgym.models.layer import (
    MLP,
    BatchNorm1dNode,
    new_layer_config,
)
from torch_geometric.graphgym.register import register_head, register_network

# from grit.head.gnngraphhead import GNNGraphHead
from grit.slt.sparse_modules import SparseLinear, SparseLinearMulti_mask


class FeatureEncoder(torch.nn.Module):
    """
    Encoding node and edge features.

    Builds an optional node encoder (plus optional BatchNorm) and an optional
    edge encoder (plus optional BatchNorm). The encoders are looked up by name
    in the GraphGym registry according to ``cfg.dataset``.

    Side effect: when ``cfg.dataset.edge_encoder`` is enabled, constructing
    this module writes ``cfg.gnn.dim_edge`` (capped at 128 for PNA layers).

    Args:
        dim_in (int): Input feature dimension. When a node encoder is used,
            ``self.dim_in`` is replaced by ``cfg.gnn.dim_inner``.
    """

    def __init__(self, dim_in):
        super(FeatureEncoder, self).__init__()
        self.dim_in = dim_in
        if cfg.dataset.node_encoder:
            # Encode integer node features via nn.Embeddings
            NodeEncoder = register.node_encoder_dict[
                cfg.dataset.node_encoder_name
            ]
            self.node_encoder = NodeEncoder(cfg.gnn.dim_inner)
            if cfg.dataset.node_encoder_bn:
                self.node_encoder_bn = BatchNorm1dNode(
                    new_layer_config(
                        cfg.gnn.dim_inner,
                        -1,
                        -1,
                        has_act=False,
                        has_bias=False,
                        cfg=cfg,
                    )
                )
            # Update dim_in to reflect the new dimension for the node features
            self.dim_in = cfg.gnn.dim_inner
        if cfg.dataset.edge_encoder:
            # Hard-limit max edge dim for PNA.
            if "PNA" in cfg.gt.layer_type:
                cfg.gnn.dim_edge = min(128, cfg.gnn.dim_inner)
            else:
                cfg.gnn.dim_edge = cfg.gnn.dim_inner
            # Encode integer edge features via nn.Embeddings
            EdgeEncoder = register.edge_encoder_dict[
                cfg.dataset.edge_encoder_name
            ]
            self.edge_encoder = EdgeEncoder(cfg.gnn.dim_edge)
            if cfg.dataset.edge_encoder_bn:
                # NOTE(review): this uses the *node* BatchNorm class for edge
                # features. GraphGym also provides BatchNorm1dEdge, which
                # targets edge_attr. Kept as-is to preserve existing
                # behaviour/checkpoints; confirm which was intended.
                self.edge_encoder_bn = BatchNorm1dNode(
                    new_layer_config(
                        cfg.gnn.dim_edge,
                        -1,
                        -1,
                        has_act=False,
                        has_bias=False,
                        cfg=cfg,
                    )
                )

    def forward(
        self,
        batch,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Apply every child encoder to ``batch`` in registration order.

        Children named ``AtomEncoder``/``BondEncoder``, and GraphGym
        ``BatchNorm1dNode`` layers, are always called with the batch alone.
        BatchNorm1dNode's ``forward`` takes no extra arguments, so passing
        the SLT kwargs to it would raise ``TypeError``. When SLT masking is
        enabled (``cfg.slt.sm`` or ``cfg.slt.mm``), every other child also
        receives ``cur_epoch`` and the threshold keyword arguments.
        Otherwise, all children are called with the batch alone.

        Returns:
            The batch as returned by the last child.
        """
        slt_enabled = cfg.slt.sm is True or cfg.slt.mm is True
        for module in self.children():
            plain_call = module.__class__.__name__ in (
                "BondEncoder",
                "AtomEncoder",
            ) or isinstance(module, BatchNorm1dNode)

            if slt_enabled and not plain_call:
                batch = module(
                    batch,
                    cur_epoch=cur_epoch,
                    mpnn_th=mpnn_th,
                    msa_th=msa_th,
                    ffn_th=ffn_th,
                    encoder_th=encoder_th,
                    pred_th=pred_th,
                    global_th=global_th,
                )
            else:
                batch = module(batch)

        return batch


class CustomSequential(nn.Module):
    """Sequential container that optionally forwards a masking threshold.

    Modules are applied in the order given, each receiving the previous
    module's output. If a threshold is available, it is passed to every
    module as the keyword ``threshold``; otherwise each module is called
    with the input alone. ``encoder_th`` takes precedence over
    ``global_th``.
    """

    def __init__(self, *args):
        super(CustomSequential, self).__init__()
        self.modules_list = nn.ModuleList(args)

    def forward(
        self,
        input,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Run the contained modules in order.

        ``cur_epoch``, ``mpnn_th``, ``msa_th``, ``ffn_th`` and ``pred_th`` are
        accepted so callers can use one uniform call signature. They are not
        used here.

        Returns:
            The output of the last module, or ``input`` unchanged if the
            container is empty.
        """
        # Default to None when neither threshold is given. Previously
        # ``threshold`` was left unbound here, which raised UnboundLocalError.
        threshold = encoder_th if encoder_th is not None else global_th

        for module in self.modules_list:
            if threshold is not None:
                input = module(input, threshold=threshold)
            else:
                input = module(input)
        return input


# for SLT in cifar10
@register_head("custom_graph")
class GNNGraphHead(nn.Module):
    """
    GNN prediction head for graph prediction tasks.

    The pooled node embeddings (``cfg.model.graph_pooling``) are transformed
    either by a stack of sparse SLT linear layers (when SLT masking is on and
    ``cfg.slt.pred`` is set) or by a GraphGym MLP with
    ``cfg.gnn.layers_post_mp`` layers.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension. For binary prediction, dim_out=1.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        use_slt_head = (
            cfg.slt.sm is True or cfg.slt.mm is True
        ) and cfg.slt.pred is True
        if use_slt_head:
            # cfg.slt.sm selects SparseLinear; otherwise cfg.slt.mm must be
            # set, which selects SparseLinearMulti_mask.
            # NOTE(review): unlike the MLP branch below, no activation is
            # placed between these layers. Confirm this is intended.
            sparse_cls = (
                SparseLinear if cfg.slt.sm is True else SparseLinearMulti_mask
            )
            self.layers = nn.ModuleList(
                [
                    sparse_cls(dim_in, dim_in)
                    for _ in range(cfg.gnn.layers_post_mp - 1)
                ]
            )
            self.layers.append(sparse_cls(dim_in, dim_out))
        else:
            self.layer_post_mp = MLP(
                new_layer_config(
                    dim_in,
                    dim_out,
                    cfg.gnn.layers_post_mp,
                    has_act=False,
                    has_bias=True,
                    cfg=cfg,
                )
            )
        self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling]

    def _apply_index(self, batch):
        """Return ``(batch.graph_feature, batch.y)``."""
        return batch.graph_feature, batch.y

    def forward(
        self,
        batch,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Pool node features per graph and produce predictions.

        ``pred_th`` takes precedence over ``global_th`` as the threshold
        passed positionally to each sparse layer. If both are None, the
        layers are called without a threshold. The other keyword arguments
        are accepted for a uniform call signature and are unused.

        Returns:
            Tuple ``(pred, label)`` taken from ``batch.graph_feature`` and
            ``batch.y``.
        """
        graph_emb = self.pooling_fun(batch.x, batch.batch)
        # Default to None when neither threshold is given. Previously
        # ``threshold`` was left unbound and the sparse branch raised
        # UnboundLocalError.
        threshold = pred_th if pred_th is not None else global_th

        if hasattr(self, "layers"):
            for layer in self.layers:
                graph_emb = (
                    layer(graph_emb, threshold)
                    if threshold is not None
                    else layer(graph_emb)
                )
        else:
            graph_emb = self.layer_post_mp(graph_emb)
        batch.graph_feature = graph_emb
        pred, label = self._apply_index(batch)
        return pred, label


@register_network("GritTransformer")
class GritTransformer(torch.nn.Module):
    """
    The proposed GritTransformer (Graph Inductive Bias Transformer)

    Child modules, in registration (and therefore execution) order:
    feature encoder, optional RRWP absolute/relative encoders, optional
    pre-message-passing layers, the stack of transformer layers, and the
    prediction head.

    Side effect: a ``cfg.gnn.head`` of ``"graph"`` is rewritten to
    ``"custom_graph"`` so the SLT-aware head is used.

    Args:
        dim_in (int): Input feature dimension.
        dim_out (int): Output dimension of the prediction head.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        # Final value of the former True/False toggle; kept in case external
        # code reads it.
        self.ablation = False

        if cfg.posenc_RRWP.enable:
            self.rrwp_abs_encoder = register.node_encoder_dict["rrwp_linear"](
                cfg.posenc_RRWP.ksteps, cfg.gnn.dim_inner
            )
            rel_pe_dim = cfg.posenc_RRWP.ksteps
            self.rrwp_rel_encoder = register.edge_encoder_dict["rrwp_linear"](
                rel_pe_dim,
                cfg.gnn.dim_edge,
                pad_to_full_graph=cfg.gt.attn.full_attn,
                add_node_attr_as_self_loop=False,
                fill_value=0.0,
            )

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp
            )
            dim_in = cfg.gnn.dim_inner

        assert (
            cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in
        ), "The inner and hidden dims must match."

        global_model_type = cfg.gt.get("layer_type", "GritTransformer")
        TransformerLayer = register.layer_dict.get(global_model_type)
        if TransformerLayer is None and cfg.gt.layers > 0:
            # Fail with a clear message instead of
            # "'NoneType' object is not callable".
            raise ValueError(
                f"Unknown transformer layer type: {global_model_type!r}"
            )

        layers = [
            TransformerLayer(
                in_dim=cfg.gt.dim_hidden,
                out_dim=cfg.gt.dim_hidden,
                num_heads=cfg.gt.n_heads,
                dropout=cfg.gt.dropout,
                act=cfg.gnn.act,
                attn_dropout=cfg.gt.attn_dropout,
                layer_norm=cfg.gt.layer_norm,
                batch_norm=cfg.gt.batch_norm,
                residual=True,
                norm_e=cfg.gt.attn.norm_e,
                O_e=cfg.gt.attn.O_e,
                cfg=cfg,
            )
            for _ in range(cfg.gt.layers)
        ]

        # CustomSequential forwards a masking threshold to the layers; a plain
        # Sequential is used when neither MSA nor FFN is SLT-masked.
        if cfg.slt.msa is True or cfg.slt.ffn is True:
            self.layers = CustomSequential(*layers)
        else:
            self.layers = torch.nn.Sequential(*layers)

        # for SLT: swap the stock graph head for the SLT-aware one.
        if cfg.gnn.head == "graph":
            cfg.gnn.head = "custom_graph"

        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(
        self,
        batch,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Run all child modules in registration order on ``batch``.

        When SLT masking is enabled (``cfg.slt.sm`` or ``cfg.slt.mm``), each
        child receives ``cur_epoch`` and the threshold keyword arguments.
        The exception is a plain ``torch.nn.Sequential`` layer stack: its
        ``forward`` accepts only the input, so it is always called with the
        batch alone. Without SLT, all children get just the batch.

        Returns:
            Whatever the last child (the prediction head) returns.
        """
        slt_enabled = cfg.slt.sm is True or cfg.slt.mm is True
        for module in self.children():
            if slt_enabled and not isinstance(module, torch.nn.Sequential):
                batch = module(
                    batch,
                    cur_epoch,
                    mpnn_th=mpnn_th,
                    msa_th=msa_th,
                    ffn_th=ffn_th,
                    encoder_th=encoder_th,
                    pred_th=pred_th,
                    global_th=global_th,
                )
            else:
                batch = module(batch)
        return batch
