import numpy as np
import operator
import random
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch import cuda
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm
import torch.nn.functional as F


class ProgDecoder(nn.Module):
	"""Single-step LSTM decoder that predicts the next program token.

	Each call to ``forward`` consumes one program token per batch element,
	concatenates its embedding with the IO embedding, runs one LSTM step and
	returns log-probabilities over the program vocabulary plus the new LSTM
	state.
	"""

	def __init__(self, args):
		"""Build the decoder from a config namespace.

		Args:
			args: object exposing ``cuda``, ``io_embedding_size``,
				``prog_embedding_size``, ``num_RNN_layers``, ``hidden_size``,
				``dropout_rate`` and ``prog_vocab_size`` attributes.
		"""
		super(ProgDecoder, self).__init__()
		self.cuda_flag = args.cuda
		self.io_embedding_size = args.io_embedding_size
		self.prog_embedding_size = args.prog_embedding_size
		self.num_layers = args.num_RNN_layers
		# The LSTM input is the IO embedding concatenated with the token embedding.
		self.input_size = self.io_embedding_size + self.prog_embedding_size
		self.hidden_size = args.hidden_size
		self.dropout_rate = args.dropout_rate
		self.prog_vocab_size = args.prog_vocab_size

		self.dropout = nn.Dropout(p=self.dropout_rate)
		self.prog_embedding = nn.Embedding(self.prog_vocab_size, self.prog_embedding_size)
		# nn.LSTM only applies dropout between stacked layers; with a single
		# layer it has no effect and PyTorch emits a warning, so pass 0 there.
		lstm_dropout = self.dropout_rate if self.num_layers > 1 else 0.0
		self.decoder = nn.LSTM(input_size=self.input_size,
			hidden_size=self.hidden_size,
			num_layers=self.num_layers,
			batch_first=True,
			dropout=lstm_dropout)
		self.output_linear_layer = nn.Linear(self.hidden_size, self.prog_vocab_size)
		# Normalise over the vocabulary (last) axis explicitly; the implicit-dim
		# form is deprecated and picks dim 0 for 3-D inputs.
		self.output_softmax_layer = nn.LogSoftmax(dim=-1)

	def predict(self, decoder_output):
		"""Map decoder hidden outputs to log-probabilities over the vocabulary.

		Args:
			decoder_output: tensor whose last dimension is ``hidden_size``.

		Returns:
			Tensor of the same leading shape with last dimension
			``prog_vocab_size``, holding log-softmax scores.
		"""
		output = self.dropout(decoder_output)
		output_linear = self.output_linear_layer(output)
		output_softmax = self.output_softmax_layer(output_linear)
		return output_softmax

	def forward(self, io_embedding, prog_token, prog_state):
		"""Run one decoding step.

		Args:
			io_embedding: tensor of shape ``(batch, io_embedding_size)``.
			prog_token: index tensor of shape ``(batch,)`` with the current
				program token of each batch element.
			prog_state: LSTM ``(h, c)`` state tuple, or ``None`` to start
				from zeros.

		Returns:
			``(prediction, state)`` where ``prediction`` has shape
			``(batch, prog_vocab_size)`` (log-probabilities) and ``state`` is
			the updated LSTM ``(h, c)`` tuple.
		"""
		# (batch,) -> (batch, 1, prog_embedding_size)
		prog_token_embedding = self.prog_embedding(prog_token.unsqueeze(1))
		# (batch, io_embedding_size) -> (batch, 1, io_embedding_size)
		io_embedding = io_embedding.unsqueeze(1)
		decoder_input = torch.cat([io_embedding, prog_token_embedding], 2)
		output, state = self.decoder(decoder_input, prog_state)
		# output is (batch, 1, hidden_size) because batch_first=True and the
		# sequence length is 1. Squeeze ONLY the time axis: a bare squeeze()
		# would also drop batch or hidden dims of size 1 and corrupt the shape.
		output_step = output.squeeze(1)
		prediction = self.predict(output_step)

		return prediction, state

