import warnings

import torch
import torch.nn as nn
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_node_encoder
from torchvision.ops import MLP
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_node_encoder
from performer_pytorch import SelfAttention
from dirgt.layer.transformer_layer import *

# Small positive constant (1e-8). Nothing in the visible part of this module
# references it; presumably a numerical-stability epsilon kept for other
# code -- TODO confirm it is still needed.
EPS = 1E-8


@register_node_encoder('SpecPosEnc')
class SpecPosEnc(torch.nn.Module):
    """Composite node encoder that sums the outputs of the selected encoders.

    The selection is read from ``cfg.dataset.node_encoders``, a ``"+"``
    separated string such as ``"LinearNode+LapPE+MagLapPE"``. Recognised
    tokens are ``LinearNode``, ``LapPE`` and ``MagLapPE``. When
    ``cfg.dataset.name == 'MColiDataset'`` an additional ``FBAEncoder`` over
    ``batch.fba_pred`` is added on top.

    All enabled encoders are combined by element-wise addition, so their
    outputs must be broadcast-compatible with each other.

    Args:
        dim_emb: Embedding size handed to the sub-encoders.
    """

    def __init__(self,
                 dim_emb):
        super(SpecPosEnc, self).__init__()

        self.encoders = cfg.dataset.node_encoders
        # Tokenise exactly the way ``forward`` does. The previous substring
        # tests (e.g. ``"+LapPE" in string``) disagreed with ``forward`` when
        # a positional encoder was listed first, so ``forward`` then looked
        # up a sub-module that had never been created.
        selected = set(self.encoders.split("+"))

        if "LinearNode" in selected:
            print("Using LinearNode encoder")
            self.linearEnc = LinearNodeEncoder(cfg.share.dim_in, dim_emb)
        if "LapPE" in selected:
            print("Using LapPE encoder")
            self.lapEnc = LaplacianEncoder(dim_emb)
        if "MagLapPE" in selected:
            print("Using MagLapPE encoder")
            self.magLapEnc = MagLapEncoder(dim_emb)

        self.fbaEnc = None
        if cfg.dataset.name == 'MColiDataset':
            # Used when fine-tuning on the MColi dataset: the extra
            # ``fba_pred`` feature is encoded and added to the node input.
            self.fbaEnc = FBAEncoder(1, dim_emb)

    def forward(self,
                batch):
        """Overwrite ``batch.x`` with the summed encodings and return ``batch``.

        Args:
            batch: Object exposing ``x`` plus whatever attributes the enabled
                sub-encoders read.

        Returns:
            The same ``batch`` object, with ``batch.x`` replaced.
        """
        selected = self.encoders.split("+")

        # Without a LinearNode encoder the raw features pass through as-is.
        x = self.linearEnc(batch) if "LinearNode" in selected else batch.x

        if "LapPE" in selected:
            x = x + self.lapEnc(batch)
        if "MagLapPE" in selected:
            x = x + self.magLapEnc(batch)
        if self.fbaEnc is not None:
            x = x + self.fbaEnc(batch)

        batch.x = x
        return batch

class FBAEncoder(torch.nn.Module):
    """Single linear layer applied to ``batch.fba_pred``.

    Args:
        dim_in: Size of the last dimension of ``batch.fba_pred``.
        emb_dim: Size of the produced embedding.
    """

    def __init__(self, dim_in, emb_dim):
        super(FBAEncoder, self).__init__()
        self.encoder = torch.nn.Linear(in_features=dim_in,
                                       out_features=emb_dim)

    def forward(self, batch):
        """Return the linear projection of ``batch.fba_pred``."""
        return self.encoder(batch.fba_pred)

class LinearNodeEncoder(torch.nn.Module):
    """Single linear layer applied to the node features ``batch.x``.

    Args:
        dim_in: Size of the last dimension of ``batch.x``.
        emb_dim: Size of the produced embedding.
    """

    def __init__(self, dim_in, emb_dim):
        super(LinearNodeEncoder, self).__init__()
        self.encoder = torch.nn.Linear(in_features=dim_in,
                                       out_features=emb_dim)

    def forward(self, batch):
        """Return the linear projection of ``batch.x``."""
        return self.encoder(batch.x)

class MagLapEncoder(torch.nn.Module):
    """Positional encoder over precomputed magnetic-Laplacian eigenpairs.

    Pipeline implemented by ``forward``:

    1. Split every eigenvector entry into real (and optionally imaginary)
       channels and pass it through ``element_mlp``; with ``use_signnet`` the
       MLP is also applied to the negated input and both results are summed.
    2. Optionally prepend the eigenvalue as an extra channel.
    3. Optionally normalise and run self-attention over the frequency axis
       (``use_attention_pre_aggr``), added back as a residual.
    4. Flatten the last two axes, apply dropout and ``re_aggregate_mlp``
       (plus ``im_aggregate_mlp`` as imaginary part when it is configured).
    5. Optionally run self-attention across rows of the aggregated output
       (``use_attention_post_aggr``), added back as a residual.

    All hyper-parameters are read from ``cfg.posenc_MagLapPE``.

    Args:
        dim_emb: Not used by this class; the output width is
            ``cfg.posenc_MagLapPE.d_model_aggr``.

    Raises:
        NotImplementedError: If ``cfg.posenc_MagLapPE.use_gnn`` is truthy.
    """

    def __init__(self, dim_emb):
        super(MagLapEncoder, self).__init__()

        pecfg = cfg.posenc_MagLapPE
        d_model_elem = pecfg.d_model_elem
        d_model_aggr = pecfg.d_model_aggr
        num_heads = pecfg.num_heads
        n_layers = pecfg.n_layers
        attn_layers = pecfg.attn_layers
        dropout_attn = pecfg.dropout_attn
        return_real_output = pecfg.return_real_output
        norm = pecfg.norm
        max_freqs = pecfg.eigen.max_freqs

        self.concatenate_eigenvalues = pecfg.concatenate_eigenvalues
        self.consider_im_part = pecfg.consider_im_part
        self.use_signnet = pecfg.use_signnet
        self.use_gnn = pecfg.use_gnn
        self.use_attention_pre_aggr = pecfg.use_attention_pre_aggr
        self.use_attention_post_aggr = pecfg.use_attention_post_aggr
        self.num_heads = num_heads
        self.dropout_p = pecfg.dropout_p

        if self.use_gnn:
            raise NotImplementedError("GNN not implemented for MagLapEncoder")

        # Channel counts follow ``consider_im_part``: (real, imag) -> 2
        # inputs and a doubled width, real only -> 1 input and single width.
        # Previously 2 inputs / doubled width were assumed downstream even
        # when the imaginary part was disabled, which failed in ``forward``.
        if self.consider_im_part:
            elem_in, elem_out = 2, int(2 * d_model_elem)
        else:
            elem_in, elem_out = 1, d_model_elem
        self.element_mlp = MLP(elem_in, [elem_out] * n_layers)

        # Per-frequency feature width after the optional eigenvalue channel.
        dim = elem_out + 1 if self.concatenate_eigenvalues else elem_out

        self.re_aggregate_mlp = MLP(dim * max_freqs,
                                    [d_model_aggr] * n_layers)

        self.im_aggregate_mlp = None
        if not return_real_output and self.consider_im_part:
            self.im_aggregate_mlp = MLP(dim * max_freqs,
                                        [d_model_aggr] * n_layers)

        if self.use_attention_pre_aggr:
            layer = nn.TransformerEncoderLayer(d_model=dim,
                                               nhead=num_heads,
                                               batch_first=True,
                                               dropout=dropout_attn)
            self.attn_pre = self._stack_attention(layer, attn_layers)

        self.norm = torch.nn.LayerNorm(dim) if norm else None

        if self.use_attention_post_aggr:
            layer = nn.TransformerEncoderLayer(d_model=d_model_aggr,
                                               dim_feedforward=d_model_aggr,
                                               nhead=num_heads,
                                               batch_first=True,
                                               dropout=dropout_attn)
            self.attn_post = self._stack_attention(layer, attn_layers)

    @staticmethod
    def _stack_attention(layer, attn_layers):
        """Return ``layer`` itself, or a stack of ``attn_layers`` copies."""
        if attn_layers > 1:
            # Depth is governed by ``attn_layers``; the MLP depth
            # ``n_layers`` was used here before, which ignored the setting.
            return nn.TransformerEncoder(layer, num_layers=attn_layers)
        return layer

    def forward(self,
                batch):
        """Compute the encoding from ``batch.MagEigVals`` / ``batch.MagEigVecs``.

        Args:
            batch: Object exposing ``MagEigVals`` and ``MagEigVecs``.
                Assumed to be ``(num_nodes, max_freqs)`` tensors, which is
                what the aggregate-MLP input size ``dim * max_freqs``
                implies -- TODO confirm against the preprocessing code.

        Returns:
            Tensor whose last dimension is the aggregate-MLP output width,
            with one row per leading index of the eigenvector tensor. It is
            complex when ``im_aggregate_mlp`` is configured.

        Raises:
            ValueError: If either attribute is missing from ``batch``.
        """
        if not (hasattr(batch, 'MagEigVals') and hasattr(batch, 'MagEigVecs')):
            raise ValueError("Precomputed eigen values and vectors are "
                             f"required for {self.__class__.__name__}; "
                             "set config 'posenc_MagLapPE.enable' to True")

        eigenvalues = batch.MagEigVals
        eigenvectors = batch.MagEigVecs

        # One trailing channel axis: [..., 0] real part, [..., 1] imaginary.
        trans_eig = eigenvectors.real[..., None]
        if self.consider_im_part:
            if torch.is_complex(eigenvectors):
                imag_part = eigenvectors.imag[..., None]
            else:
                # A real eigenvector has zero imaginary part; keep the
                # channel so the input width matches ``element_mlp``.
                imag_part = torch.zeros_like(trans_eig)
            trans_eig = torch.cat((trans_eig, imag_part), dim=-1)

        trans = self.element_mlp(trans_eig)
        if self.use_signnet:
            # Sign-invariant embedding: f(v) + f(-v).
            trans = trans + self.element_mlp(-trans_eig)

        if self.concatenate_eigenvalues:
            trans = torch.cat((eigenvalues[..., None], trans), dim=-1)

        if self.use_attention_pre_aggr:
            if self.norm is not None:
                # ``self.norm`` is a module instance: call it on the tensor.
                trans = self.norm(trans)
            empty_mask = torch.isnan(trans)
            # The layer is batch_first, so the key mask is (batch, seq) and
            # must not be transposed.
            attn_output = self.attn_pre(
                src=trans, src_key_padding_mask=empty_mask[:, :, 0])
            trans = trans + attn_output

        # Merge the frequency and feature axes into one vector per row.
        trans = trans.flatten(start_dim=-2)

        if self.dropout_p:
            # Functional dropout tied to ``self.training``; a freshly built
            # ``nn.Dropout`` is always in training mode and would also drop
            # activations during evaluation.
            trans = nn.functional.dropout(trans, p=self.dropout_p,
                                          training=self.training)

        output = self.re_aggregate_mlp(trans)
        if self.im_aggregate_mlp is not None:
            output = output + 1j * self.im_aggregate_mlp(trans)

        if self.use_attention_post_aggr:
            empty_mask = torch.isnan(output)
            # Treat all rows as one sequence: add a batch axis of size 1 ...
            sequence = output[None, :]
            key_mask = empty_mask[None, :, 0]
            attn_output = self.attn_post(src=sequence,
                                         src_key_padding_mask=key_mask)
            # ... and strip it again. Indexing ``[0]`` only here keeps the
            # non-attention path from returning just the first row.
            output = (sequence + attn_output)[0]

        return output


class LaplacianEncoder(torch.nn.Module):
    """Laplacian positional encoder driven by ``cfg.posenc_LapPE``.

    Every (eigenvector entry, eigenvalue) pair is projected by ``linear_A``,
    processed by either a Transformer encoder layer or a DeepSet-style MLP,
    zeroed where the raw input was NaN, mean-pooled over the frequency axis
    and optionally passed through ``post_mlp``.

    Args:
        dim_emb: Not used by this class; the output width is
            ``cfg.posenc_LapPE.dim_pe``.

    Raises:
        ValueError: If ``cfg.posenc_LapPE.model`` is neither
            ``'Transformer'`` nor ``'DeepSet'``.
    """

    def __init__(self, dim_emb):
        super(LaplacianEncoder, self).__init__()

        pecfg = cfg.posenc_LapPE
        dim_pe = pecfg.dim_pe            # width of the Laplace PE embedding
        model_type = pecfg.model         # which encoder network to build
        if model_type not in ('Transformer', 'DeepSet'):
            raise ValueError(f"Unexpected PE model {model_type}")
        self.model_type = model_type
        n_layers = pecfg.layers          # depth of the PE encoder
        n_heads = pecfg.n_heads          # attention heads (Transformer only)
        post_n_layers = pecfg.post_layers  # depth of the post-pooling MLP
        max_freqs = pecfg.eigen.max_freqs  # number of eigenvectors
        norm_type = pecfg.raw_norm_type.lower()  # raw PE normalisation
        self.pass_as_var = pecfg.pass_as_var  # stored; not read in this class

        wide = 2 * dim_pe
        activation = nn.ReLU

        # Projection of the (eigenvector value, eigenvalue) pair.
        self.linear_A = nn.Linear(2, dim_pe)
        self.raw_norm = (nn.BatchNorm1d(max_freqs)
                         if norm_type == 'batchnorm' else None)

        if model_type == 'Transformer':
            self.pe_encoder = nn.TransformerEncoderLayer(
                d_model=dim_pe,
                dim_feedforward=dim_pe,
                nhead=n_heads,
                batch_first=True)
        else:
            # DeepSet: the first linear map is ``linear_A`` itself, so the
            # stack opens with an activation.
            stack = [activation()]
            if n_layers != 1:
                # Deeper variant: ``linear_A`` is rebuilt with a doubled
                # output width and the projection built above is discarded.
                self.linear_A = nn.Linear(2, wide)
                for _ in range(n_layers - 2):
                    stack += [nn.Linear(wide, wide), activation()]
                stack += [nn.Linear(wide, dim_pe), activation()]
            self.pe_encoder = nn.Sequential(*stack)

        self.post_mlp = None
        if post_n_layers > 0:
            # MLP applied after pooling.
            if post_n_layers == 1:
                stack = [nn.Linear(dim_pe, dim_pe), activation()]
            else:
                stack = [nn.Linear(dim_pe, wide), activation()]
                for _ in range(post_n_layers - 2):
                    stack += [nn.Linear(wide, wide), activation()]
                stack += [nn.Linear(wide, dim_pe), activation()]
            self.post_mlp = nn.Sequential(*stack)

    def forward(self, batch):
        """Return the pooled positional encoding for ``batch``.

        Args:
            batch: Object exposing ``EigVals`` and ``EigVecs``, each indexed
                as (Num nodes) x (Num Eigenvectors).

        Returns:
            Tensor of shape (Num nodes) x dim_pe.

        Raises:
            ValueError: If either attribute is missing from ``batch``.
        """
        has_eigen = hasattr(batch, 'EigVals') and hasattr(batch, 'EigVecs')
        if not has_eigen:
            raise ValueError("Precomputed eigen values and vectors are "
                             f"required for {self.__class__.__name__}; "
                             "set config 'posenc_LapPE.enable' to True")
        eig_vals, eig_vecs = batch.EigVals, batch.EigVecs

        if self.training:
            # Random per-eigenvector sign flip, applied during training only.
            coin = torch.rand(eig_vecs.size(1), device=eig_vecs.device)
            signs = torch.where(coin >= 0.5,
                                torch.ones_like(coin),
                                -torch.ones_like(coin))
            eig_vecs = eig_vecs * signs.unsqueeze(0)

        # (Num nodes) x (Num Eigenvectors) x 2
        pos_enc = torch.stack((eig_vecs, eig_vals), dim=2)
        nan_mask = torch.isnan(pos_enc)
        pos_enc = pos_enc.masked_fill(nan_mask, 0)

        if self.raw_norm is not None:
            pos_enc = self.raw_norm(pos_enc)
        # (Num nodes) x (Num Eigenvectors) x dim_pe
        pos_enc = self.linear_A(pos_enc)

        # Entries whose eigenvector value was NaN count as padding.
        key_mask = nan_mask[:, :, 0]
        if self.model_type == 'Transformer':
            pos_enc = self.pe_encoder(src=pos_enc,
                                      src_key_padding_mask=key_mask)
        else:
            pos_enc = self.pe_encoder(pos_enc)

        # Zero the padded positions (out of place, so the encoder output
        # tensor itself is left untouched).
        pos_enc = pos_enc.masked_fill(key_mask.unsqueeze(2), 0.)

        # Mean pooling over the eigenvector axis -> (Num nodes) x dim_pe
        pos_enc = pos_enc.mean(dim=1)

        if self.post_mlp is not None:
            pos_enc = self.post_mlp(pos_enc)

        return pos_enc
