# _*_ coding: utf-8 _*_

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F

def simple_elementwise_apply(fn, packed_sequence):
    """Apply a pointwise function ``fn`` to every element of ``packed_sequence``.

    ``fn`` receives the flat ``packed_sequence.data`` tensor and must return a
    tensor with the same leading (packed) dimension. The packing metadata is
    carried over unchanged, including ``sorted_indices`` / ``unsorted_indices``.
    Dropping those would make sequences packed with ``enforce_sorted=False``
    come back in length-sorted order instead of the original batch order when
    unpacked.

    Returns a new PackedSequence; the input is not modified.
    """
    return torch.nn.utils.rnn.PackedSequence(
        fn(packed_sequence.data),
        packed_sequence.batch_sizes,
        packed_sequence.sorted_indices,
        packed_sequence.unsorted_indices,
    )

class LSTMClassifier(nn.Module):
    """Sequence classifier: input projection -> stacked LSTM -> linear head.

    Arguments
    ---------
    output_size : number of output classes (size of the final linear layer).
    hidden_size : size of the LSTM hidden state.
    vocab_size : number of distinct token ids; only used when wordske == "word".
    embedding_length : size of the per-timestep vectors fed into the LSTM.
    wordske : input mode.
        "word" -- inputs are integer token ids looked up in an nn.Embedding
                  (id 0 is the padding index).
        "ske"  -- inputs are 51-dim float feature vectors (17*3; presumably
                  17 skeleton joints x 3 values -- confirm) projected with an
                  nn.Linear.
    num_layers : number of stacked LSTM layers (default 2, the value that was
        previously hard-coded). Stored as ``self.num_layer``.

    Raises
    ------
    NotImplementedError : if ``wordske`` is neither "word" nor "ske".
    """

    def __init__(self, output_size, hidden_size, vocab_size, embedding_length, wordske, num_layers=2):
        super(LSTMClassifier, self).__init__()

        self.output_size = output_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_layer = num_layers
        self.embedding_length = embedding_length
        if wordske == "word":
            # Token-id lookup table; index 0 is reserved for padding.
            self.word_embeddings = nn.Embedding(vocab_size, embedding_length, padding_idx=0)
        elif wordske == "ske":
            # Projects 17 * 3 = 51 input features per timestep into embedding space.
            self.word_embeddings = nn.Linear(17 * 3, embedding_length)
        else:
            raise NotImplementedError(
                "unsupported wordske mode {!r}; expected 'word' or 'ske'".format(wordske))
        self.lstm = nn.LSTM(embedding_length, hidden_size, num_layers=self.num_layer)
        self.label = nn.Linear(hidden_size, output_size)

    def forward(self, input_sentence, lengths):
        """Classify a padded, time-major batch of sequences.

        Parameters
        ----------
        input_sentence : padded batch laid out as (seq_len, batch_size) token
            ids in "word" mode, or (seq_len, batch_size, 51) floats in "ske"
            mode. Dimension 1 is the batch dimension.
        lengths : valid length of each sequence (list of ints or 1-D tensor).
            Sequences do not need to be sorted by length. A tensor is moved to
            the CPU, because pack_padded_sequence requires CPU lengths.

        Returns
        -------
        Tensor of shape (batch_size, output_size). These are unnormalised
        scores (no activation) from the linear head, applied to the top LSTM
        layer's hidden state at each sequence's last valid timestep, in the
        original batch order.
        """
        batch_size = input_sentence.shape[1]
        embedded = self.word_embeddings(input_sentence)  # (seq_len, batch_size, embedding_length)
        if torch.is_tensor(lengths):
            # pack_padded_sequence rejects lengths tensors that live on a CUDA device.
            lengths = lengths.cpu()
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, lengths, enforce_sorted=False)
        # Zero initial states on the same device/dtype as the embedded input,
        # instead of a hard-coded .cuda(), so the model also runs on CPU.
        h_0 = embedded.new_zeros(self.num_layer, batch_size, self.hidden_size)
        c_0 = embedded.new_zeros(self.num_layer, batch_size, self.hidden_size)
        _, (final_hidden_state, _) = self.lstm(packed, (h_0, c_0))
        # final_hidden_state: (num_layers, batch_size, hidden_size). With a packed
        # input it holds each sequence's state at its last valid step, already
        # restored to the caller's batch order. Classify from the top layer.
        return self.label(final_hidden_state[-1])
