import torch
import torch.nn as nn


class Laplacian(torch.nn.Module):
    """Laplacian positional-encoding (LapPE) encoder.

    For every node the raw positional encoding is a set of
    (eigenvector entry, eigenvalue) pairs, one pair per Laplacian frequency.
    Each pair is projected by a shared linear layer, processed by either a
    Transformer encoder or a DeepSet-style MLP, sum-pooled over the
    frequencies (frequencies whose eigenvector entry is NaN are ignored) and
    optionally refined by a post-pooling MLP.

    ``forward`` returns only the encoded positional embedding of shape
    (Num nodes) x dim_pe; combining it with node features is left to the
    caller.

    Args:
        cfg: Configuration object read through attribute access. Fields used:
            ``cfg.model.pos_encoder.emb_dim``: size of the PE embedding.
            ``cfg.model.pos_encoder.params.model``: ``"Transformer"`` or
            ``"DeepSet"``.
            ``cfg.model.pos_encoder.params.layers``: number of layers in the
            PE encoder. For ``"DeepSet"`` the first layer is the shared input
            projection; any value other than 1 builds at least two linear
            layers.
            ``cfg.model.pos_encoder.params.n_heads``: attention heads of the
            Transformer encoder (read for both model types, used only by the
            Transformer).
            ``cfg.model.pos_encoder.params.post_layers``: number of layers of
            the MLP applied after pooling; no MLP is built when it is <= 0.
            ``cfg.model.pos_encoder.params.raw_norm_type``: ``"batchnorm"``
            (case-insensitive) enables batch normalisation of the raw PE;
            any other value disables it.
            ``cfg.model.num_eigenvecs``: number of eigenvectors (frequencies)
            per node.

    Raises:
        ValueError: If ``params.model`` is neither ``"Transformer"`` nor
            ``"DeepSet"``.
    """

    def __init__(self, cfg):
        super().__init__()

        pe_cfg = cfg.model.pos_encoder
        params = pe_cfg.params

        self.dim_pe = pe_cfg.emb_dim  # Size of PE embedding
        model_type = params.model  # Encoder NN model type for PEs
        if model_type not in ("Transformer", "DeepSet"):
            raise ValueError(f"Unexpected PE model {model_type}")
        self.model_type = model_type
        n_layers = params.layers  # Num. layers in PE encoder model
        n_heads = params.n_heads  # Num. attention heads in Trf PE encoder
        post_n_layers = params.post_layers  # Num. layers to apply after pooling
        max_freqs = cfg.model.num_eigenvecs  # Num. eigenvectors (frequencies)
        norm_type = params.raw_norm_type.lower()  # Raw PE normalization layer type

        # Initial projection of eigenvalue and the node's eigenvector value
        self.linear_A = nn.Linear(2, self.dim_pe)
        # BatchNorm1d treats dim 1 (the frequency axis) as the channel axis.
        self.raw_norm = nn.BatchNorm1d(max_freqs) if norm_type == "batchnorm" else None

        activation = nn.ReLU
        if model_type == "Transformer":
            # Transformer model for LapPE
            encoder_layer = nn.TransformerEncoderLayer(d_model=self.dim_pe, nhead=n_heads, batch_first=True)
            self.pe_encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        else:
            # DeepSet model for LapPE: `linear_A` acts as the MLP's first layer.
            if n_layers != 1:
                # A multi-layer MLP uses a 2 * dim_pe hidden width, so the
                # input projection is replaced by a wider one. It is created
                # as a second layer (rather than choosing the width up front)
                # to keep the parameter-creation order unchanged.
                self.linear_A = nn.Linear(2, 2 * self.dim_pe)
            self.pe_encoder = nn.Sequential(*self._mlp_tail(n_layers, self.dim_pe, activation))

        self.post_mlp = None
        if post_n_layers > 0:
            # MLP to apply post pooling
            first_out = self.dim_pe if post_n_layers == 1 else 2 * self.dim_pe
            layers = [nn.Linear(self.dim_pe, first_out)]
            layers.extend(self._mlp_tail(post_n_layers, self.dim_pe, activation))
            self.post_mlp = nn.Sequential(*layers)

    @staticmethod
    def _mlp_tail(n_layers, dim_pe, activation):
        """Build the layers that follow an MLP's first linear projection.

        Args:
            n_layers: Total number of linear layers of the MLP, counting the
                first projection that the caller creates itself.
            dim_pe: Output width of the MLP; hidden layers are 2 * dim_pe wide.
            activation: Zero-argument factory for the activation module.

        Returns:
            A list of modules. For ``n_layers == 1`` this is a single
            activation. Otherwise it is an activation, ``n_layers - 2`` hidden
            (linear, activation) pairs of width 2 * dim_pe, and a final
            (linear, activation) pair mapping 2 * dim_pe to dim_pe.
        """
        tail = [activation()]
        if n_layers == 1:
            return tail
        for _ in range(n_layers - 2):
            tail.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
            tail.append(activation())
        tail.append(nn.Linear(2 * dim_pe, dim_pe))
        tail.append(activation())
        return tail

    def forward(self, eigvecs, batch):
        """Encode Laplacian eigenvectors/eigenvalues into node embeddings.

        Args:
            eigvecs: Tensor of shape (Num nodes) x (Num Eigenvecs). NaN
                entries mark missing frequencies and are excluded from the
                attention (Transformer) and from the pooling.
            batch: Mapping whose ``"eigenvals"`` entry holds the eigenvalues.
                They are repeated once per node, so they must form a single
                row of (Num Eigenvecs) values for the shape checks to pass.

        Returns:
            Tensor of shape (Num nodes) x dim_pe.
        """
        eigvals = batch["eigenvals"].repeat(len(eigvecs), 1)  # (Num nodes) x (Num Eigenvecs)
        assert eigvecs.ndim == eigvals.ndim
        assert eigvecs.size(0) == eigvals.size(0)
        assert eigvecs.size(1) == eigvals.size(1)

        if self.training:
            # Random per-frequency sign flip (same flip for all nodes).
            draw = torch.rand(eigvecs.size(1), device=eigvecs.device)
            ones = torch.ones_like(draw)
            sign_flip = torch.where(draw >= 0.5, ones, -ones)
            eigvecs = eigvecs * sign_flip.unsqueeze(0)

        # Eigendecompositions are often computed in float64; bring both inputs
        # to the dtype of the projection weights so `linear_A` accepts them.
        dtype = self.linear_A.weight.dtype
        pos_enc = torch.cat(
            (eigvecs.to(dtype).unsqueeze(2), eigvals.to(dtype).unsqueeze(2)), dim=2
        )  # (Num nodes) x (Num Eigenvectors) x 2
        empty_mask = torch.isnan(pos_enc)  # (Num nodes) x (Num Eigenvectors) x 2
        # Frequencies are considered missing when the eigenvector entry is NaN.
        missing = empty_mask[:, :, 0]  # (Num nodes) x (Num Eigenvectors)

        pos_enc = pos_enc.masked_fill(empty_mask, 0.0)  # (Num nodes) x (Num Eigenvectors) x 2
        if self.raw_norm is not None:
            pos_enc = self.raw_norm(pos_enc)
        pos_enc = self.linear_A(pos_enc)  # (Num nodes) x (Num Eigenvectors) x dim_pe (2 * dim_pe for a multi-layer DeepSet)

        # PE encoder: a Transformer or DeepSet model
        if self.model_type == "Transformer":
            pos_enc = self.pe_encoder(src=pos_enc, src_key_padding_mask=missing)
        else:
            pos_enc = self.pe_encoder(pos_enc)

        # Zero out missing frequencies so they do not contribute to the sum
        # (out-of-place, so the encoder output itself is not overwritten).
        pos_enc = pos_enc.masked_fill(missing.unsqueeze(2), 0.0)

        # Sum pooling
        pos_enc = torch.sum(pos_enc, 1, keepdim=False)  # (Num nodes) x dim_pe

        # MLP post pooling
        if self.post_mlp is not None:
            pos_enc = self.post_mlp(pos_enc)  # (Num nodes) x dim_pe

        return pos_enc
