import torch
import torch.nn as nn
from torch_geometric.data import Batch
from torch_geometric.utils import to_dense_batch
from torch_geometric.nn import norm
from torch_geometric.utils import to_dense_adj
from torch_geometric.graphgym.models.gnn import GNNPreMP
from torch_geometric.graphgym.models.layer import new_layer_config, BatchNorm1dNode
import torch_geometric.nn as pygnn
from torch_geometric.nn import Linear as Linear_pyg
import torch.nn.functional as F
from .node_encoders import node_encoder_dict
from .heads import head_dict
from .attention import Attention
from .edge_encoders import edge_encoder_dict
from .async_mpnn import FlowLayer
from .edgerwse_encoder import RWSEEdgeEncoder



class GPSLayer(nn.Module):
    """Hybrid layer: optional local message passing plus optional global attention.

    The layer reads ``batch.x`` and runs the configured local model
    (``"FLOW"``) and/or global attention model (``"HA"`` or ``"Transformer"``).
    It combines their outputs, applies a residual two-layer feed-forward block,
    and writes the result back to ``batch.x``.

    Normalization is one of:
    - PyG ``LayerNorm``, applied per graph via ``batch.batch``;
    - ``nn.BatchNorm1d``;
    - none.

    Reference: https://github.com/rampasek/GraphGPS.

    Args:
        dim_h: Hidden feature dimension (input and output width of the layer).
        local_gnn_type: ``"None"`` or ``"FLOW"``.
        global_model_type: ``"None"``, ``"HA"`` or ``"Transformer"``.
        num_heads: Number of attention heads.
        pna_degrees: Accepted for interface compatibility; not used here.
        equivstable_pe: Stored on the instance; not consulted in ``forward``.
        dropout: Dropout rate for the residual branches and the FF block.
        attn_dropout: Dropout rate passed to the attention module.
        layer_norm: Use PyG ``LayerNorm``. Mutually exclusive with ``batch_norm``.
        batch_norm: Use ``nn.BatchNorm1d``.
        bigbird_cfg: Accepted for interface compatibility; not used here.
        log_attn_weights: Keep per-head attention weights in
            ``self.attn_weights`` (Transformer global model only).
        dag_cfg: Config object providing ``attn_after_mpnn`` and ``ff``.
            For ``"FLOW"`` it must also provide ``bidirectional`` and
            ``conv_type``. May be ``None`` when no FLOW model is used.

    Raises:
        NotImplementedError: If attention logging is requested for an
            unsupported global model.
        ValueError: On unsupported model types, on requesting both norms, or
            on requesting ``"FLOW"`` without a ``dag_cfg``.
    """

    def __init__(
        self,
        dim_h,
        local_gnn_type,
        global_model_type,
        num_heads,
        pna_degrees=None,
        equivstable_pe=False,
        dropout=0.0,
        attn_dropout=0.0,
        layer_norm=False,
        batch_norm=True,
        bigbird_cfg=None,
        log_attn_weights=False,
        dag_cfg=None,
    ):
        super().__init__()

        self.dim_h = dim_h
        self.num_heads = num_heads
        self.attn_dropout = attn_dropout
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm
        self.equivstable_pe = equivstable_pe
        self.activation = nn.ReLU

        self.log_attn_weights = log_attn_weights
        if log_attn_weights and global_model_type not in [
            "Transformer",
            "BiasedTransformer",
        ]:
            raise NotImplementedError(
                f"Logging of attention weights is not supported " f"for '{global_model_type}' global attention model."
            )
        # dag_cfg defaults to None, so read its flags defensively rather than
        # failing with AttributeError when no DAG config is supplied.
        self.attn_after_mpnn = getattr(dag_cfg, "attn_after_mpnn", False)
        # NOTE(review): ``self.ff`` is stored, but forward() always applies the
        # feed-forward block. Confirm whether this flag is meant to gate it.
        self.ff = getattr(dag_cfg, "ff", True)

        # Local message-passing model.
        self.local_gnn_with_edge_attr = True
        if local_gnn_type == "None":
            self.local_model = None
        elif local_gnn_type == "FLOW":
            if dag_cfg is None:
                raise ValueError("local_gnn_type 'FLOW' requires a dag_cfg")
            # The class imported from .async_mpnn is ``FlowLayer``.
            self.local_model = FlowLayer(
                dim_h,
                dim_h,
                dropout,
                dag_cfg.bidirectional,
                dag_cfg.conv_type,
                dag_cfg=dag_cfg,
            )
        else:
            raise ValueError(f"Unsupported local GNN model: {local_gnn_type}")
        self.local_gnn_type = local_gnn_type

        # Global attention transformer-style model.
        if global_model_type == "None":
            self.self_attn = None
        elif global_model_type == "HA":
            self.self_attn = Attention(dim_h, num_heads, dropout=self.attn_dropout, bias=False)
        elif global_model_type in ['Transformer']:
            self.self_attn = torch.nn.MultiheadAttention(
                dim_h, num_heads, dropout=self.attn_dropout, batch_first=True)
        else:
            raise ValueError(f"Unsupported global x-former model: " f"{global_model_type}")
        self.global_model_type = global_model_type

        if self.layer_norm and self.batch_norm:
            raise ValueError("Cannot apply two types of normalization together")

        if self.layer_norm:
            self.norm1_local = norm.LayerNorm(dim_h)
            self.norm1_attn = norm.LayerNorm(dim_h)
        elif self.batch_norm:
            self.norm1_local = nn.BatchNorm1d(dim_h)
            self.norm1_attn = nn.BatchNorm1d(dim_h)
        else:
            # No normalization. Define the attributes anyway so _apply_norm can
            # treat every case uniformly; None registers no parameters.
            self.norm1_local = None
            self.norm1_attn = None
        self.dropout_local = nn.Dropout(dropout)
        self.dropout_attn = nn.Dropout(dropout)

        # Feed Forward block.
        self.ff_linear1 = nn.Linear(dim_h, dim_h * 2)
        self.ff_linear2 = nn.Linear(dim_h * 2, dim_h)
        self.act_fn_ff = self.activation()

        if self.layer_norm:
            self.norm2 = norm.LayerNorm(dim_h)
        elif self.batch_norm:
            self.norm2 = nn.BatchNorm1d(dim_h)
        else:
            self.norm2 = None
        self.ff_dropout1 = nn.Dropout(dropout)
        self.ff_dropout2 = nn.Dropout(dropout)

    def forward(self, batch):
        """Update ``batch.x`` in place and return ``batch``."""
        h = batch.x
        h_in1 = h  # Layer input, reused by the residual connections.
        h_local = None

        h_out_list = []
        if self.local_model is not None:
            # Only "FLOW" can reach here (__init__ rejects anything else).
            # Its return value's ``.x`` holds the updated node features.
            local_out = self.local_model(batch)
            h_local = self.dropout_local(local_out.x)
            h_local = h_in1 + h_local  # Residual connection.
            h_local = self._apply_norm(self.norm1_local, h_local, batch)
            if self.attn_after_mpnn:
                # Sequential mode: attention consumes the local output.
                h = h_local
            h_out_list.append(h_local)

        # Multi-head attention.
        if self.self_attn is not None:
            if self.global_model_type == "HA":
                h_attn, _ = self.self_attn(
                    h,
                    dag_rr_edge_index=batch.s_edge_index
                )
            elif self.global_model_type == 'Transformer':
                h_attn = self._transformer_attn(h, batch)
            else:
                raise RuntimeError(f"Unexpected {self.global_model_type}")

            h_attn = self.dropout_attn(h_attn)
            # Only fold in the local branch when one actually exists; otherwise
            # h_local is undefined for this configuration.
            if self.attn_after_mpnn and h_local is not None:
                h_attn = h_local + h_attn
                h_attn = self._apply_norm(self.norm1_attn, h_attn, batch)
            h_attn = h_in1 + h_attn  # Residual connection.
            h_attn = self._apply_norm(self.norm1_attn, h_attn, batch)
            h_out_list.append(h_attn)

        # Combine local and global outputs.
        if self.self_attn is not None and self.attn_after_mpnn:
            h = h_attn
        elif h_out_list:
            h = sum(h_out_list)
        else:
            # Neither sub-model is configured, so pass the features through.
            # sum([]) would give the int 0 and break the FF block.
            h = h_in1

        # Feed Forward block.
        h = h + self._ff_block(h)
        h = self._apply_norm(self.norm2, h, batch)

        batch.x = h
        return batch

    def _apply_norm(self, norm_layer, h, batch):
        """Apply ``norm_layer`` according to the configured normalization type.

        PyG ``LayerNorm`` receives the ``batch.batch`` graph-assignment vector.
        ``BatchNorm1d`` takes features only. A ``None`` layer returns ``h``
        unchanged.
        """
        if norm_layer is None:
            return h
        if self.layer_norm:
            return norm_layer(h, batch.batch)
        return norm_layer(h)

    def _transformer_attn(self, h, batch):
        """Dense, adjacency-masked multi-head self-attention within each graph."""
        h_dense, mask = to_dense_batch(h, batch.batch)
        # Pad the adjacency to exactly the node count used by to_dense_batch so
        # the attention-mask shape always matches h_dense. The previous size was
        # derived from batch.Eigvecs / batch.y, which assumed equal-sized graphs.
        s_adj = to_dense_adj(batch.edge_index, batch.batch, max_num_nodes=h_dense.size(1))
        # NOTE(review): for a boolean attn_mask, PyTorch treats True as
        # "not allowed to attend", so this blocks attention along existing
        # edges. Confirm this is intended rather than restricting to edges.
        att_mask = s_adj == 1
        # MultiheadAttention expects a 3-D mask of shape (B * num_heads, L, S).
        attn_mask_multihead = att_mask.repeat_interleave(self.num_heads, dim=0)
        # ~mask marks padded keys to ignore; indexing by [mask] drops padded rows.
        return self._sa_block(h_dense, attn_mask_multihead, ~mask)[mask]

    def _sa_block(self, x, attn_mask, key_padding_mask):
        """Run self-attention, optionally storing per-head weights on CPU."""
        if not self.log_attn_weights:
            x = self.self_attn(
                x,
                x,
                x,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask,
                need_weights=False,
            )[0]
        else:
            x, A = self.self_attn(
                x,
                x,
                x,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask,
                need_weights=True,
                average_attn_weights=False,
            )
            self.attn_weights = A.detach().cpu()
        return x

    def _ff_block(self, x):
        """Two-layer feed-forward: Linear -> ReLU -> Dropout -> Linear -> Dropout."""
        x = self.ff_dropout1(self.act_fn_ff(self.ff_linear1(x)))
        return self.ff_dropout2(self.ff_linear2(x))

    def extra_repr(self):
        s = (
            f"summary: dim_h={self.dim_h}, "
            f"local_gnn_type={self.local_gnn_type}, "
            f"global_model_type={self.global_model_type}, "
            f"heads={self.num_heads}"
        )
        return s


class FeatureEncoder(torch.nn.Module):
    """Encode node and (optionally) edge features ahead of message passing.

    Sub-modules are registered in this order:
    node encoder, node BatchNorm, edge encoder, edge BatchNorm.
    ``forward`` applies whichever of them exist, in that order.

    Args:
        dim_in (int): Input node feature dimension. When a node encoder is
            enabled, ``self.dim_in`` is replaced by ``cfg.gnn.dim_inner``.
        cfg: Global config. Reads ``cfg.dataset.*`` flags and
            ``cfg.gnn.dim_inner``. Enabling the edge encoder also writes
            ``cfg.gnn.dim_edge``.
    """

    def __init__(self, dim_in, cfg):
        super().__init__()
        self.dim_in = dim_in
        dataset_cfg = cfg.dataset

        if dataset_cfg.node_encoder:
            # Look up the node-encoder class by name in the registry.
            encoder_cls = node_encoder_dict[dataset_cfg.node_encoder_name]
            self.node_encoder = encoder_cls(cfg.gnn.dim_inner, cfg)
            if dataset_cfg.node_encoder_bn:
                node_bn_cfg = new_layer_config(
                    cfg.gnn.dim_inner, -1, -1, has_act=False, has_bias=False, cfg=cfg
                )
                self.node_encoder_bn = BatchNorm1dNode(node_bn_cfg)
            # Downstream layers now see the encoded width.
            self.dim_in = cfg.gnn.dim_inner

        if dataset_cfg.edge_encoder:
            # Edge embeddings share the node hidden width (mutates cfg).
            cfg.gnn.dim_edge = cfg.gnn.dim_inner
            # The edge encoder is always RWSEEdgeEncoder;
            # cfg.dataset.edge_encoder_name is not consulted here.
            self.edge_encoder = RWSEEdgeEncoder(cfg.gnn.dim_edge, cfg)
            if dataset_cfg.edge_encoder_bn:
                edge_bn_cfg = new_layer_config(
                    cfg.gnn.dim_edge, -1, -1, has_act=False, has_bias=False, cfg=cfg
                )
                self.edge_encoder_bn = BatchNorm1dNode(edge_bn_cfg)

    def forward(self, batch):
        # Thread the batch through each registered encoder in registration order.
        for stage in self.children():
            batch = stage(batch)
        return batch


class TEFormer(torch.nn.Module):
    """Graph transformer: feature encoder, optional pre-MP, a GPS-layer stack, and a head.

    Children are registered in the order ``encoder``, ``pre_mp`` (only when
    ``cfg.gnn.layers_pre_mp > 0``), ``layers``, ``post_mp``. ``forward``
    applies them sequentially to the batch.

    Args:
        dim_in: Raw input node feature dimension.
        dim_out: Output dimension passed to the prediction head.
        cfg: Global config. Reads ``cfg.gnn``, ``cfg.gt``, ``cfg.dag``,
            ``cfg.posenc_EquivStableLapPE`` and ``cfg.train``.

    Raises:
        ValueError: If the hidden dimensions disagree, or if
            ``cfg.gt.layer_type`` is not of the form ``"<local>+<global>"``.
    """

    def __init__(self, dim_in, dim_out, cfg):
        super().__init__()
        self.encoder = FeatureEncoder(dim_in, cfg)
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        if not cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in:
            raise ValueError(
                f"The inner and hidden dims must match: "
                f"embed_dim={cfg.gt.dim_hidden} dim_inner={cfg.gnn.dim_inner} "
                f"dim_in={dim_in}"
            )

        try:
            local_gnn_type, global_model_type = cfg.gt.layer_type.split("+")
        except (AttributeError, ValueError) as err:
            # AttributeError: layer_type is not a string.
            # ValueError: the string does not contain exactly one '+'.
            # Chain the cause so the original failure stays visible.
            raise ValueError(f"Unexpected layer type: {cfg.gt.layer_type}") from err

        self.layers = torch.nn.Sequential(
            *[
                GPSLayer(
                    dim_h=cfg.gt.dim_hidden,
                    local_gnn_type=local_gnn_type,
                    global_model_type=global_model_type,
                    num_heads=cfg.gt.n_heads,
                    pna_degrees=cfg.gt.pna_degrees,
                    equivstable_pe=cfg.posenc_EquivStableLapPE.enable,
                    dropout=cfg.gt.dropout,
                    attn_dropout=cfg.gt.attn_dropout,
                    layer_norm=cfg.gt.layer_norm,
                    batch_norm=cfg.gt.batch_norm,
                    bigbird_cfg=cfg.gt.bigbird,
                    log_attn_weights=cfg.train.mode == "log-attn-weights",
                    dag_cfg=cfg.dag,
                )
                for _ in range(cfg.gt.layers)
            ]
        )

        GNNHead = head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out, cfg=cfg)

    def forward(self, batch):
        # Apply every child module in registration order.
        for module in self.children():
            batch = module(batch)
        return batch

    def embedding(self, batch):
        """Return the head's embedding of ``batch`` without tracking gradients.

        Runs every child except ``post_mp``, then ``post_mp.embedding``.
        NOTE(review): this does not switch the model to eval mode. Call
        ``model.eval()`` first if dropout/BatchNorm should be deterministic.
        """
        with torch.no_grad():
            for module in list(self.children())[:-1]:
                batch = module(batch)
            batch = self.post_mp.embedding(batch)
        return batch
