import torch
import torch.nn as nn

# import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_node_encoder

from graphgps.slt.monarch_linear import MonarchLinear
from graphgps.slt.sparse_modules import (
    NMSparseMultiLinear,
    SparseLinear,
    SparseLinearMulti_mask,
)


class CustomSequential(nn.Module):
    """Sequential container that forwards a second argument to its layers.

    Every contained module is called as ``module(x, y)``, except
    ``nn.ReLU`` instances, which are called as ``module(x)``.

    Args:
        *args: Modules to apply, in order.
    """

    def __init__(self, *args):
        super().__init__()
        # Attribute name is part of the state_dict key prefix; keep it stable.
        self.submodules = nn.ModuleList(args)

    def forward(self, x, y):
        """Run ``x`` through all modules, handing ``y`` to non-ReLU ones.

        Args:
            x: Input passed through the chain of modules.
            y: Extra argument given unchanged to every non-ReLU module.

        Returns:
            Output of the last module (``x`` itself if there are none).
        """
        out = x
        for layer in self.submodules:
            out = layer(out) if isinstance(layer, nn.ReLU) else layer(out, y)
        return out


@register_node_encoder("LapPE")
class LapPENodeEncoder(torch.nn.Module):
    """Laplace Positional Embedding node encoder.

    LapPE of size dim_pe will get appended to each node feature vector.
    If `expand_x` set True, original node features will be first linearly
    projected to (dim_emb - dim_pe) size and then concatenated with LapPE.

    The type of linear layer used for the projections is chosen from the
    global config (``cfg.slt`` / ``cfg.monarch``); see ``_build_linear``.

    Args:
        dim_emb: Size of final node embedding
        expand_x: Expand node features `x` from dim_in to (dim_emb - dim_pe)

    Raises:
        ValueError: If the configured PE model is unknown, if ``dim_pe``
            exceeds ``dim_emb``, or if ``cfg.slt.encoder`` is enabled
            without any layer variant being selected.
    """

    def __init__(self, dim_emb, expand_x=True):
        super().__init__()
        dim_in = cfg.share.dim_in  # Expected original input node features dim

        pecfg = cfg.posenc_LapPE
        dim_pe = pecfg.dim_pe  # Size of Laplace PE embedding
        model_type = pecfg.model  # Encoder NN model type for PEs
        if model_type not in ["Transformer", "DeepSet"]:
            raise ValueError(f"Unexpected PE model {model_type}")
        self.model_type = model_type
        n_layers = pecfg.layers  # Num. layers in PE encoder model
        n_heads = pecfg.n_heads  # Num. attention heads in Trf PE encoder
        post_n_layers = pecfg.post_layers  # Num. layers to apply after pooling
        max_freqs = pecfg.eigen.max_freqs  # Num. eigenvectors (frequencies)
        # Raw PE normalization layer type
        norm_type = pecfg.raw_norm_type.lower()
        # Pass PE also as a separate variable
        self.pass_as_var = pecfg.pass_as_var

        # formerly `< 1`, but you could have zero feature size
        if dim_emb - dim_pe < 0:
            raise ValueError(
                f"LapPE size {dim_pe} is too large for "
                f"desired embedding size of {dim_emb}."
            )

        self.expand_x = expand_x and dim_emb - dim_pe > 0
        if self.expand_x:
            self.linear_x = self._build_linear(dim_in, dim_emb - dim_pe)

        # Initial projection of eigenvalue and the node's eigenvector value
        self.linear_A = self._build_linear(2, dim_pe)

        if norm_type == "batchnorm":
            self.raw_norm = nn.BatchNorm1d(
                max_freqs, affine=cfg.slt.batch_affine
            )
        else:
            self.raw_norm = None

        activation = nn.ReLU  # register.act_dict[cfg.gnn.act]
        if model_type == "Transformer":
            # Transformer model for LapPE
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=dim_pe, nhead=n_heads, batch_first=True
            )
            self.pe_encoder = nn.TransformerEncoder(
                encoder_layer, num_layers=n_layers
            )
        else:
            # DeepSet model for LapPE
            layers = []
            if n_layers == 1:
                layers.append(activation())
            else:
                # The initial projection is widened to 2 * dim_pe. It is
                # re-created here (instead of being sized up front) so that
                # layers are constructed in the same order as before and
                # seeded initialisation stays reproducible.
                self.linear_A = self._build_linear(2, 2 * dim_pe)
                layers.append(activation())
                for _ in range(n_layers - 2):
                    layers.append(self._build_linear(2 * dim_pe, 2 * dim_pe))
                    layers.append(activation())
                layers.append(self._build_linear(2 * dim_pe, dim_pe))
                layers.append(activation())
            if self._threshold_enabled():
                # Layers take (x, threshold); ReLU only takes x.
                self.pe_encoder = CustomSequential(*layers)
            else:
                self.pe_encoder = nn.Sequential(*layers)

        self.post_mlp = None
        if post_n_layers > 0:
            # MLP to apply post pooling; always built from dense nn.Linear
            # layers, independent of the sparse-layer configuration.
            layers = []
            if post_n_layers == 1:
                layers.append(nn.Linear(dim_pe, dim_pe))
                layers.append(activation())
            else:
                layers.append(nn.Linear(dim_pe, 2 * dim_pe))
                layers.append(activation())
                for _ in range(post_n_layers - 2):
                    layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
                    layers.append(activation())
                layers.append(nn.Linear(2 * dim_pe, dim_pe))
                layers.append(activation())
            self.post_mlp = nn.Sequential(*layers)

    @staticmethod
    def _build_linear(dim_in, dim_out):
        """Create the linear layer variant selected by the global config.

        With ``cfg.slt.encoder`` disabled a dense ``nn.Linear`` is returned.
        Otherwise the first enabled option wins, in this order:
        ``cfg.monarch.encoder``, ``cfg.slt.srste``, ``cfg.slt.sm``,
        ``cfg.slt.mm``.

        Raises:
            ValueError: If ``cfg.slt.encoder`` is enabled but none of the
                variants above is.
        """
        if cfg.slt.encoder is not True:
            return nn.Linear(dim_in, dim_out)
        if cfg.monarch.encoder is True:
            return MonarchLinear(dim_in, dim_out, bias=False)
        if cfg.slt.srste is True:
            return NMSparseMultiLinear(dim_in, dim_out)
        if cfg.slt.sm is True:
            return SparseLinear(dim_in, dim_out, bias=False)
        if cfg.slt.mm is True:
            return SparseLinearMulti_mask(dim_in, dim_out, bias=False)
        # Previously this case left the layer undefined and only failed
        # later, inside forward(), with an unrelated-looking error.
        raise ValueError(
            "cfg.slt.encoder is enabled but none of cfg.monarch.encoder, "
            "cfg.slt.srste, cfg.slt.sm or cfg.slt.mm is set; cannot choose "
            "a linear layer type for the LapPE encoder"
        )

    @staticmethod
    def _threshold_enabled():
        """Return whether layers are called with a pruning threshold.

        NOTE(review): layer construction gives monarch/srste precedence
        over sm/mm, while this check only looks at sm/mm. This presumably
        relies on those flags being mutually exclusive -- confirm.
        """
        return (
            cfg.slt.sm is True or cfg.slt.mm is True
        ) and cfg.slt.encoder is True

    @staticmethod
    def _resolve_threshold(encoder_th, global_th):
        """Pick the threshold matching ``cfg.slt.pruning``.

        Raises:
            ValueError: If the pruning mode is neither "blockwise" nor
                "global". Previously such a mode silently skipped the
                projection layers.
        """
        pruning = cfg.slt.pruning
        if pruning == "blockwise":
            return encoder_th
        if pruning == "global":
            return global_th
        raise ValueError(
            f"Unexpected pruning mode {pruning!r}; "
            "expected 'blockwise' or 'global'"
        )

    def forward(
        self,
        batch,
        cur_epoch=None,
        mpnn_th=None,
        msa_th=None,
        ffn_th=None,
        encoder_th=None,
        pred_th=None,
        global_th=None,
    ):
        """Encode Laplacian PEs and append them to the node features.

        Args:
            batch: Batch object providing ``x``, ``EigVals`` and ``EigVecs``.
            cur_epoch, mpnn_th, msa_th, ffn_th, pred_th: Accepted for
                interface compatibility; not read by this encoder.
            encoder_th: Threshold passed to the sparse layers when
                ``cfg.slt.pruning == "blockwise"``.
            global_th: Threshold passed to the sparse layers when
                ``cfg.slt.pruning == "global"``.

        Returns:
            The same ``batch`` with ``batch.x`` replaced by the
            concatenation of the (optionally projected) node features and
            the PE, and ``batch.pe_LapPE`` set if ``pass_as_var`` is enabled.

        Raises:
            ValueError: If the batch lacks precomputed eigen data, or if
                thresholds are in use and ``cfg.slt.pruning`` is unknown.
        """
        if not (hasattr(batch, "EigVals") and hasattr(batch, "EigVecs")):
            raise ValueError(
                "Precomputed eigen values and vectors are "
                f"required for {self.__class__.__name__}; "
                "set config 'posenc_LapPE.enable' to True"
            )
        EigVals = batch.EigVals
        EigVecs = batch.EigVecs

        use_threshold = self._threshold_enabled()
        threshold = (
            self._resolve_threshold(encoder_th, global_th)
            if use_threshold
            else None
        )

        if self.training:
            # Random sign flip of each eigenvector during training.
            sign_flip = torch.rand(EigVecs.size(1), device=EigVecs.device)
            sign_flip[sign_flip >= 0.5] = 1.0
            sign_flip[sign_flip < 0.5] = -1.0
            EigVecs = EigVecs * sign_flip.unsqueeze(0)

        pos_enc = torch.cat(
            (EigVecs.unsqueeze(2), EigVals), dim=2
        )  # (Num nodes) x (Num Eigenvectors) x 2
        empty_mask = torch.isnan(
            pos_enc
        )  # (Num nodes) x (Num Eigenvectors) x 2

        pos_enc[empty_mask] = 0  # (Num nodes) x (Num Eigenvectors) x 2
        if self.raw_norm is not None:
            pos_enc = self.raw_norm(pos_enc)

        # (Num nodes) x (Num Eigenvectors) x dim_pe
        # (2 * dim_pe for a multi-layer DeepSet, reduced by pe_encoder)
        if use_threshold:
            pos_enc = self.linear_A(pos_enc, threshold)
        else:
            pos_enc = self.linear_A(pos_enc)

        # PE encoder: a Transformer or DeepSet model
        if self.model_type == "Transformer":
            pos_enc = self.pe_encoder(
                src=pos_enc, src_key_padding_mask=empty_mask[:, :, 0]
            )
        elif use_threshold:
            pos_enc = self.pe_encoder(pos_enc, threshold)
        else:
            pos_enc = self.pe_encoder(pos_enc)

        # Remove masked sequences; must clone before overwriting masked elements
        pos_enc = pos_enc.clone().masked_fill_(
            empty_mask[:, :, 0].unsqueeze(2), 0.0
        )

        # Sum pooling
        pos_enc = torch.sum(pos_enc, 1, keepdim=False)  # (Num nodes) x dim_pe

        # MLP post pooling. post_mlp is a plain nn.Sequential of dense
        # layers, so it takes a single input; passing a threshold as a
        # second positional argument raised a TypeError.
        if self.post_mlp is not None:
            pos_enc = self.post_mlp(pos_enc)  # (Num nodes) x dim_pe

        # Expand node features if needed
        if self.expand_x:
            if use_threshold:
                h = self.linear_x(batch.x, threshold)
            else:
                h = self.linear_x(batch.x)
        else:
            h = batch.x
        # Concatenate final PEs to input embedding
        batch.x = torch.cat((h, pos_enc), 1)
        # Keep PE also separate in a variable (e.g. for skip connections to input)
        if self.pass_as_var:
            batch.pe_LapPE = pos_enc
        return batch
