import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer


class TransformerModel(nn.Module):
    """Transformer encoder with an optional input projection and several heads.

    The input is optionally projected by a linear layer, optionally combined
    with encoded labels, positional encodings and ego-segment embeddings, and
    passed through a stack of ``TransformerEncoderLayer`` modules. Depending on
    the constructor flags, ``forward`` returns raw encoder outputs at selected
    positions, a decoded first position, or a decoded pooled representation.

    Args:
        ntoken: Size of the last dimension of the input features.
        ninp: Model width. When ``ninp <= 0`` no input projection is created
            and both the model width and the feed-forward width are set to
            ``ntoken``.
        nhead: Number of attention heads.
        nhid: Feed-forward width of each encoder layer (replaced by ``ntoken``
            when ``ninp <= 0``).
        nlayers: Number of encoder layers.
        nclasses: Output size of the decoder and input size of the label
            encoder.
        idropout: Dropout probability applied to the encoder input.
        hdropout: Dropout probability used inside the encoder layers.
        return_outputs: If true, no decoder is created and ``forward`` returns
            raw encoder outputs instead of decoded predictions.
        graph_classification: If true, the encoder output is pooled over the
            sequence dimension before decoding; otherwise only position 0 is
            decoded.
        return_outputs_2: If true, a two-entry segment embedding is added to
            the input (entry 0 at positions ``0`` and ``seq_len // 2``, entry 1
            elsewhere) and, together with ``return_outputs``, the outputs at
            both of those positions are returned.
        ignore_padding: With ``graph_classification``, average only over
            positions that are not marked in ``padding``.
        layer_norm: If truthy, apply ``LayerNorm`` to the input before dropout.
        use_label: If truthy, create the linear ``label_encoder`` that is
            required when ``forward`` is called with ``src_label``.
        ldropout: Dropout probability applied to ``src_label``.
        src_scale: If truthy, multiply the (projected) input by
            ``sqrt(model width)``.
    """

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, nclasses, idropout=0.1, hdropout=0.5,
                 return_outputs=False, graph_classification=False, return_outputs_2=False,
                 ignore_padding=False, layer_norm=0, use_label=0, ldropout=0.5, src_scale=0):
        super().__init__()
        self.model_type = 'Transformer'

        # NOTE: submodules are created in a fixed order so that seeded
        # initialisation and state-dict keys stay stable.
        if ninp > 0:
            self.encoder = nn.Linear(ntoken, ninp)
        else:
            # No projection: the encoder stack works directly on the raw
            # feature width.
            self.encoder = None
            ninp = nhid = ntoken
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, hdropout, activation='gelu')
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.ninp = ninp
        self.link_pred = return_outputs
        self.return_outputs_2 = return_outputs_2
        if not return_outputs:
            self.decoder = nn.Linear(ninp, nclasses)
        if return_outputs_2:
            # segment embedding for ego nodes
            self.ego_embedding = nn.Embedding(2, ninp)
        self.graph_classification = graph_classification
        self.ignore_padding = ignore_padding
        self.dropout = nn.Dropout(p=idropout)
        self.ldropout = nn.Dropout(p=ldropout)
        self.src_scale = src_scale
        self.layer_norm = layer_norm
        if layer_norm:
            self.input_layer_norm = nn.LayerNorm(ninp)
        if use_label:
            self.label_encoder = nn.Linear(nclasses, ntoken)

    def init_weights(self):
        """Re-initialise every parameter with more than one dimension (Xavier uniform)."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _add_ego_embedding(self, src):
        """Return ``src`` plus the segment embedding, without modifying ``src``.

        ``src`` is sequence-first here. Entry 1 of the embedding is added to
        every position; positions ``0`` and ``seq_len // 2`` are then shifted
        so that they effectively carry entry 0 instead.
        """
        e0 = self.ego_embedding.weight[0]
        e1 = self.ego_embedding.weight[1]
        # Out-of-place add: ``src`` may still be a view of the caller's input
        # (no projection / scaling / positional encoding applied), and an
        # in-place update would silently modify that input.
        src = src + e1
        offset = e0 - e1
        # In-place is safe from here on: ``src`` is a freshly allocated tensor.
        # NOTE: for a sequence of length 1 both indices refer to row 0, so the
        # offset is applied twice (behaviour kept from the original code).
        src[0] += offset
        src[src.shape[0] // 2] += offset
        return src

    def forward(self, src, src_mask=None, padding=None, pe=None, src_label=None):
        """Encode ``src`` and return outputs according to the configured head.

        Args:
            src: Input features. Dimensions 0 and 1 are swapped before the
                encoder, which is built with the default sequence-first
                layout, so ``src`` is expected as (batch, seq, ntoken).
            src_mask: Optional attention mask forwarded to the encoder.
            padding: Optional mask forwarded as ``src_key_padding_mask``; its
                negation selects the positions averaged when
                ``ignore_padding`` is set, so it must be a boolean tensor in
                that case.
            pe: Optional tensor added to the projected input after swapping
                its dimensions 0 and 1.
            src_label: Optional labels; they pass through label dropout and
                ``label_encoder`` and are added to ``src``. Requires the model
                to have been built with ``use_label``.

        Returns:
            * ``return_outputs`` and ``return_outputs_2``: a tuple with the
              encoder outputs at positions ``0`` and ``seq_len // 2``.
            * ``return_outputs`` only: the encoder output at position ``0``.
            * ``graph_classification``: the decoded average over sequence
              positions (restricted to non-padded ones when
              ``ignore_padding`` is set and ``padding`` is given).
            * otherwise: the decoded encoder output at position ``0``.
        """
        if src_label is not None:
            src = src + self.label_encoder(self.ldropout(src_label))

        # Switch to the sequence-first layout used by the encoder stack.
        src = src.transpose(0, 1)
        if self.encoder is not None:
            src = self.encoder(src)
        if self.src_scale:
            src = src * math.sqrt(self.ninp)
        if pe is not None:
            src = src + pe.transpose(0, 1)
        if self.return_outputs_2:
            src = self._add_ego_embedding(src)
        if self.layer_norm:
            src = self.input_layer_norm(src)
        src = self.dropout(src)

        output = self.transformer_encoder(src, src_mask, src_key_padding_mask=padding)

        if self.link_pred:
            if self.return_outputs_2:
                return output[0], output[src.shape[0] // 2]
            return output[0]

        if not self.graph_classification:
            return self.decoder(output[0])

        if self.ignore_padding and padding is not None:
            # Masked mean over the sequence dimension.
            keep = ~padding.unsqueeze(-1)
            pooled = (output.transpose(0, 1) * keep).sum(dim=1) / keep.sum(dim=1)
        else:
            # Without a padding mask every position counts, so the masked
            # mean reduces to the plain mean.
            pooled = output.mean(dim=0)
        return self.decoder(pooled)
