import torch
import torch.nn as nn
from torch.autograd import Variable

class SeqEncoder(nn.Module):
	"""Embed program-token sequences and encode them with a stacked LSTM.

	Hyperparameters are read from ``args``: ``prog_vocab_size``,
	``max_prog_len``, ``prog_embedding_size``, ``num_RNN_layers``,
	``hidden_size``, ``dropout_rate`` and ``cuda``.
	"""

	def __init__(self, args):
		super(SeqEncoder, self).__init__()
		self.prog_vocab_size = args.prog_vocab_size
		# Stored for callers; not used inside this class.
		self.max_prog_len = args.max_prog_len
		self.prog_embedding_size = args.prog_embedding_size
		self.num_layers = args.num_RNN_layers
		self.hidden_size = args.hidden_size
		self.dropout_rate = args.dropout_rate
		# Kept for backward compatibility; forward() now places its initial
		# state on whatever device the input lives on.
		self.cuda_flag = args.cuda

		self.encoder_embedding = nn.Embedding(self.prog_vocab_size, self.prog_embedding_size)
		# nn.LSTM applies dropout only between stacked layers. With a single
		# layer it has no effect and PyTorch emits a UserWarning, so pass 0 then.
		lstm_dropout = self.dropout_rate if self.num_layers > 1 else 0.0
		self.encoder = nn.LSTM(input_size=self.prog_embedding_size,
			hidden_size=self.hidden_size,
			num_layers=self.num_layers,
			batch_first=True,
			dropout=lstm_dropout)

	def forward(self, encoder_inputs):
		"""Encode a batch of token-index sequences.

		Args:
			encoder_inputs: integer tensor of token indices, shape
				(batch, seq_len) -- batch-first, matching the LSTM config.

		Returns:
			A tuple ``(encoder_outputs, encoder_state)``:
			- encoder_outputs has shape (batch, seq_len, hidden_size).
			- encoder_state is the final ``(h_n, c_n)`` pair; each element has
			  shape (num_layers, batch, hidden_size).
		"""
		embedding = self.encoder_embedding(encoder_inputs)
		batch_size = encoder_inputs.size(0)
		# Build the zero initial state on the same device and dtype as the
		# embedded input. This replaces the deprecated Variable wrapper and the
		# hard-coded .cuda(), which broke on non-default GPUs and under
		# .half()/.double().
		state_shape = (self.num_layers, batch_size, self.hidden_size)
		init_state = (embedding.new_zeros(state_shape), embedding.new_zeros(state_shape))
		encoder_outputs, encoder_state = self.encoder(embedding, init_state)
		return encoder_outputs, encoder_state


