import torch
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import torch.nn as nn

import math

# Hyperparameters shared by the model classes defined in this file.
EMBEDDING_DIM = 200  # token-embedding width; also the Transformer encoder's d_model
HIDDEN_DIM = 1024  # Transformer feed-forward width and GRU hidden-state size
OUTPUT_DIM = 2  # output size of RNNModel's final linear layer
N_HEADS = 2  # attention heads per Transformer encoder layer
N_LAYERS = 2  # number of stacked layers in the Transformer encoder and in the GRU
BIDIRECTIONAL = True  # whether RNNModel's GRU runs in both directions
DROPOUT = 0.25  # dropout probability passed to both models


class PositionalEmbedding(nn.Module):
    r"""Inject information about the absolute position of the tokens in the sequence.

    The positional encodings have the same dimension as the embeddings, so
    that the two can be summed. Sine and cosine functions of different
    frequencies are used:

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})

    where ``pos`` is the word position and ``i`` is the embed idx.

    Args:
        d_model: the embed dim (required). Both even and odd values are
            supported; for an odd value the last column is a sine column.
        dropout: the dropout value applied to the summed output (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).

    Examples:
        >>> pos_encoder = PositionalEmbedding(d_model)
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so only the first d_model // 2 frequencies are used for the cosines.
        # For an even d_model this slice selects every frequency.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Stored as a buffer with a leading broadcast dimension:
        # [1, max_len, d_model]. Buffers follow the module across devices and
        # are saved in the state dict, but are not trainable parameters.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Add the positional encodings to ``x`` and apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).
                Its sequence length must not exceed ``max_len``.

        Shape:
            x: [batch size, sequence length, embed dim]
            output: [batch size, sequence length, embed dim]

        Examples:
            >>> output = pos_encoder(x)
        """
        # Only the first `sequence length` positions of the table are needed.
        x = x + self.pe[:, : x.shape[1], :]
        return self.dropout(x)


class TransformerModel(nn.Module):
    """Transformer-encoder sequence classifier.

    Token ids are embedded, scaled, combined with positional encodings, run
    through a stack of Transformer encoder layers and projected to class
    scores. The scores of the first sequence position are returned.

    Args:
        vocab_size: number of rows of the token-embedding table.
        classes: number of output scores produced per input sequence.
    """

    def __init__(self, vocab_size: int, classes: int) -> None:
        super().__init__()
        self.model_type = "Transformer"
        # No attention mask is ever built in this class; it stays None.
        self.src_mask = None
        self.pos_encoder = PositionalEmbedding(EMBEDDING_DIM, DROPOUT)
        encoder_layers = TransformerEncoderLayer(EMBEDDING_DIM, N_HEADS, HIDDEN_DIM, DROPOUT, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, N_LAYERS)
        self.encoder = nn.Embedding(vocab_size, EMBEDDING_DIM)
        self.EMBEDDING_DIM = EMBEDDING_DIM
        self.decoder = nn.Linear(EMBEDDING_DIM, classes)

        self.init_weights()

    def init_weights(self) -> None:
        """Initialise the embedding and the output projection.

        Embedding and decoder weights are drawn uniformly from
        ``[-0.1, 0.1]``; the decoder bias is set to zero.
        """
        initrange = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        # Zero the bias, not the weight: zeroing the weight right before
        # re-drawing it is a no-op and left the bias at its default init.
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, input_ids: torch.Tensor, has_mask: bool = False, inputs_embeds=None, **kwargs) -> torch.Tensor:
        """Compute class scores for a batch of token-id sequences.

        Args:
            input_ids: integer token ids, [batch size, sequence length].
            has_mask: accepted for interface compatibility; currently unused.
            inputs_embeds: accepted for interface compatibility; currently
                unused (the embeddings are always looked up from ``input_ids``).
            **kwargs: ignored.

        Returns:
            Tensor of shape [batch size, classes] holding the scores computed
            from the first sequence position.
        """
        # NOTE(review): the lookup runs under no_grad, so the embedding table
        # receives no gradient from this forward pass and is effectively
        # frozen at its initial values. Confirm that this is intended.
        with torch.no_grad():
            inputs = self.encoder(input_ids) * math.sqrt(EMBEDDING_DIM)
        inputs = self.pos_encoder(inputs)
        output = self.transformer_encoder(inputs, self.src_mask)
        # Only the first position is returned, so project just that position
        # instead of the whole sequence (the linear layer acts per position).
        return self.decoder(output[:, 0, :])

    def features(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return a placeholder all-zero tensor of shape [batch size, 1, 1, 1]."""
        return torch.zeros(input_ids.shape[0], 1, 1, 1)

class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder.

    Token ids are embedded (index 0 is the padding index), passed through a
    multi-layer GRU and the GRU output at the last time step is projected to
    ``OUTPUT_DIM`` scores.

    Args:
        vocab_size: number of rows of the token-embedding table.
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, EMBEDDING_DIM, padding_idx=0)

        # The attribute is called ``lstm`` although it holds a GRU; the name
        # is kept so that existing state-dict keys remain valid.
        self.lstm = nn.GRU(
            EMBEDDING_DIM,
            HIDDEN_DIM,
            num_layers=N_LAYERS,
            bidirectional=BIDIRECTIONAL,
            dropout=DROPOUT,
            batch_first=True,
        )

        # A bidirectional GRU concatenates both directions in its output.
        num_directions = 2 if BIDIRECTIONAL else 1
        self.fc = nn.Linear(num_directions * HIDDEN_DIM, OUTPUT_DIM)

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        """Compute output scores for a batch of token-id sequences.

        Args:
            text: integer token ids, [batch size, sequence length].

        Returns:
            Tensor of shape [batch size, OUTPUT_DIM].
        """
        embeds = self.embedding(text)
        lstm_out, _ = self.lstm(embeds)

        # NOTE(review): the last time step is used as the sequence summary.
        # For a bidirectional GRU the backward half of that step has only
        # processed the final token, and with right-padded batches the final
        # token may be padding. Confirm this is the intended read-out before
        # changing it; trained weights depend on it.
        out = lstm_out[:, -1, :]
        return self.fc(out)

    def features(self, text: torch.Tensor) -> torch.Tensor:
        """Return a placeholder all-zero tensor of shape [batch size, 1, 1, 1]."""
        return torch.zeros(text.shape[0], 1, 1, 1)