import numpy as np
import operator
import random
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch import cuda
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm
import torch.nn.functional as F

from .model_utils import supervisor
from .data_utils import data_utils
from .modules import ioencoder, progencoder, progdecoder
from .simulators import karel_simulator


class ProgGenModel(nn.Module):
	"""Pairs an I/O-grid encoder with a program-token decoder.

	Both sub-modules are built from ``args`` by the project's ``ioencoder`` and
	``progdecoder`` packages; their internals are not visible from this file.
	"""

	def __init__(self, args):
		super(ProgGenModel, self).__init__()
		self.io_encoder = ioencoder.IOEncoder(args)
		self.prog_decoder = progdecoder.ProgDecoder(args)

		# Negative log-likelihood with the default (mean) reduction.
		self.loss_function = nn.NLLLoss()

	def forward(self, input_grids, output_grids, prog_tokens, decoder_state):
		"""Encode the I/O grids, pool them, and run the decoder once.

		Returns the ``(predictions, state)`` pair produced by the decoder.
		"""
		encoded_ios = self.io_encoder(input_grids, output_grids)
		# Max-pool over dimension 1 and keep only the values, not the argmax indices.
		# NOTE(review): dimension 1 is presumably the I/O-example axis - confirm against IOEncoder.
		pooled_ios = encoded_ios.max(1)[0]
		step_predictions, next_state = self.prog_decoder(pooled_ios, prog_tokens, decoder_state)
		return step_predictions, next_state

	def compute_loss(self, predictions, target):
		"""Return the NLL loss of ``predictions`` against ``target``."""
		return self.loss_function(predictions, target)

class KarelModel(nn.Module):
	def __init__(self, args, prog_vocab, prog_vocab_list):
		"""Copy hyper-parameters from ``args`` and build sub-model, simulator and optimizer.

		Args:
			args: namespace of hyper-parameters and flags (every field read is listed below).
			prog_vocab: program vocabulary, stored and forwarded to the simulator.
			prog_vocab_list: list form of the vocabulary, stored and forwarded likewise.
		"""
		super(KarelModel, self).__init__()

		# Vocabulary / sequence settings.
		self.prog_vocab = prog_vocab
		self.prog_vocab_list = prog_vocab_list
		self.prog_vocab_size = args.prog_vocab_size
		# NOTE(review): the +5 margin presumably leaves room for special tokens - confirm.
		self.max_prog_len = args.max_prog_len + 5

		# Network dimensions.
		self.io_embedding_size = args.io_embedding_size
		self.io_feature_size = args.io_feature_size
		self.io_count = args.io_count
		self.grid_size = args.grid_size
		self.hidden_size = args.hidden_size
		self.num_RNN_layers = args.num_RNN_layers

		# Optimisation settings.
		self.lr = args.lr
		self.gradient_clip = args.gradient_clip
		self.dropout_rate = args.dropout_rate
		self.dropout = nn.Dropout(p=self.dropout_rate)
		self.cuda_flag = args.cuda

		# Decoding / search settings.
		self.sample_bias = args.sample_bias
		self.eval_mode = args.eval_mode
		self.baseline = args.baseline
		self.search_depth = args.search_depth
		self.beam_size = args.beam_size
		# The per-depth frontier kept by search() is twice the final beam.
		self.internal_beam_size = args.beam_size * 2

		# Scratch state that search() fills in; reset before each search.
		self.best_predicted_progs = []
		self.best_prob = []
		self.best_vf = []
		self.min_prob_limit = 0.0

		# The sub-model is registered under one of two attribute names; the
		# ``model`` property of this class reads ``policy_model``.
		# NOTE(review): the non-compatibility branch assigns to ``self.model`` even
		# though ``model`` is a property - this presumably relies on how nn.Module
		# registers and looks up sub-modules; confirm before refactoring.
		if args.compatibility_mode:
			self.policy_model = ProgGenModel(args)
		else:
			self.model = ProgGenModel(args)
		self.simulator = karel_simulator.KarelSimulator(args, prog_vocab, prog_vocab_list)

		# Build the optimizer last so it sees every registered parameter.
		optimizer_name = args.optimizer
		if optimizer_name not in ('adam', 'sgd', 'rmsprop'):
			# No optimizer attribute is created in this case.
			print('optimizer undefined: ', args.optimizer)
		elif optimizer_name == 'adam':
			self.optimizer = optim.Adam(self.parameters(), lr=self.lr)
		elif optimizer_name == 'sgd':
			self.optimizer = optim.SGD(self.parameters(), lr=self.lr)
		else:
			self.optimizer = optim.RMSprop(self.parameters(), lr=self.lr, alpha=0.95)

	@property
	def model(self):
		"""Return the sub-model stored under the ``policy_model`` attribute.

		``__init__`` registers the sub-model as ``policy_model`` only when
		``args.compatibility_mode`` is set; otherwise it assigns ``self.model``.
		"""
		# NOTE(review): when ``policy_model`` does not exist this getter raises
		# AttributeError; lookups of ``self.model`` presumably still succeed through
		# nn.Module's ``__getattr__`` fallback over registered sub-modules - confirm
		# before changing this property or the assignment in ``__init__``.
		return self.policy_model

	def init_weights(self, param_init):
		for param in self.parameters():
			param.data.uniform_(-param_init, param_init)

	def lr_decay(self, lr_decay_rate):
		self.lr *= lr_decay_rate
		for param_group in self.optimizer.param_groups:
			param_group['lr'] = self.lr

	def train_step(self):
		if self.gradient_clip > 0:
			clip_grad_norm(self.parameters(), self.gradient_clip)
		self.optimizer.step()

	def forward(self, batch_data, feed_previous=False):
		train_input_grids, train_output_grids, val_input_grids, val_output_grids, progs = self.preprocess_data(batch_data)
		return self.run_prog_gen(train_input_grids, train_output_grids, val_input_grids, val_output_grids, progs, feed_previous)

	def search(self, syntax_checker_state, init_prob, init_input_grids, init_karel_worlds, init_decoder_state, output_grids, output_karel_worlds, current_depth, cum_prob, prog_prefix, gt_prog=None):
		"""Extend every beam entry by one token, then recurse one level deeper.

		All list arguments are parallel: index ``i`` of each describes beam entry ``i``.

		Args:
			syntax_checker_state: per-entry syntax-checker state; its ``mask`` is multiplied into the token probabilities.
			init_prob: per-entry array of next-token probabilities.
			init_input_grids: per-entry initial input grids (fed to the model only when ``self.baseline`` is set).
			init_karel_worlds: per-entry list of ``io_count`` execution states reached so far.
			init_decoder_state: per-entry decoder state.
			output_grids: per-entry target output grids.
			output_karel_worlds: per-entry list of ``io_count`` target execution states.
			current_depth: remaining recursion depth; the call is a no-op once it reaches 0.
			cum_prob: per-entry cumulative probability of the prefix.
			prog_prefix: per-entry list of tokens generated so far.
			gt_prog: unused here and not forwarded to the recursive call.

		Returns:
			None. Finished programs are accumulated in ``self.best_predicted_progs``,
			``self.best_prob`` and ``self.best_vf`` (at most ``self.beam_size`` entries).
		"""
		if current_depth <= 0:
			return
		if len(init_prob) == 0:
			return
		# Frontier for the next depth, kept as parallel lists ordered by
		# descending cumulative probability (see the insertion code below).
		new_syntax_checker_state = []
		new_init_prob = []
		new_init_input_grids = []
		new_init_karel_worlds = []
		new_init_decoder_state = []
		new_output_grids = []
		new_output_karel_worlds = []
		new_cum_prob = []
		new_prog_prefix = []

		for idx in range(len(init_prob)):		
			# Apply the syntax mask, then visit tokens from most to least probable.
			current_prob = init_prob[idx] * syntax_checker_state[idx].mask
			candidate_predictions = np.argsort(current_prob, axis=0)
		
			for prog_token in candidate_predictions[::-1]:
				# Tokens are in descending order, so every `break` below also
				# discards all remaining (less probable) tokens of this entry.
				if current_prob[prog_token] <= self.sample_bias:
					break

				cur_cum_prob = cum_prob[idx] * current_prob[prog_token]
				cur_predicted_prog = prog_prefix[idx] + [prog_token]
				# Prune: the result list is full and this candidate cannot beat the current limit.
				if len(self.best_prob) >= self.beam_size and cur_cum_prob <= self.min_prob_limit:
					break
				# Prune: the next frontier is full and this candidate is no better than its last entry.
				if len(new_cum_prob) >= self.internal_beam_size and cur_cum_prob <= new_cum_prob[-1]:
					break

				# NOTE(review): m_close_token presumably terminates the program - confirm in the syntax vocab.
				# A candidate ending in it is scored and stored as a finished program, never expanded.
				if prog_token == self.simulator.syntax_checker.syntax_vocab.m_close_token:
					# NOTE(review): `crash`, `cur_syntax_checker_state` and the tensor form of
					# `input_grids` are computed in this branch but not read afterwards.
					crash = False
					input_grids = []
					karel_worlds = []
					cur_syntax_checker_state = None
					for io_idx in range(self.io_count):
						cur_karel_world, cur_syntax_checker_state = self.simulator.execute(init_karel_worlds[idx][io_idx], syntax_checker_state[idx], prog_token)
						if cur_karel_world.crash:
							crash = True
						input_grids.append(cur_karel_world.grid)
						karel_worlds.append(cur_karel_world)
					input_grids = np.array(input_grids)
					input_grids = data_utils.np_to_tensor(input_grids, 'float', cuda_flag=self.cuda_flag, volatile_flag=True)
					input_grids = input_grids.unsqueeze(0)

					# The program is consistent when every resulting world equals its target world.
					consistent = True
					for io_idx in range(self.io_count):
						if karel_worlds[io_idx] != output_karel_worlds[idx][io_idx]:
							consistent = False
							break

					# Find the insertion index that keeps the result lists ordered by
					# (vf, probability), best first. vf is 1.0 for consistent programs, else 0.0.
					if consistent:
						cur_vf = 1.0

						# Scan from the front past every strictly better entry.
						cur_idx = 0
						while cur_idx < len(self.best_vf):
							if self.best_vf[cur_idx] > cur_vf or self.best_vf[cur_idx] == cur_vf and self.best_prob[cur_idx] > cur_cum_prob:
								cur_idx += 1
							else:
								break
					else:
						cur_vf = 0.0

						# Scan from the back past every strictly worse entry.
						cur_idx = len(self.best_vf) - 1
						while cur_idx >= 0:
							if self.best_vf[cur_idx] < cur_vf or self.best_vf[cur_idx] == cur_vf and self.best_prob[cur_idx] < cur_cum_prob:
								cur_idx -= 1
							else:
								break
						cur_idx += 1
						
					self.best_predicted_progs = self.best_predicted_progs[:cur_idx] + [cur_predicted_prog] + self.best_predicted_progs[cur_idx:]
					self.best_prob = self.best_prob[:cur_idx] + [cur_cum_prob] + self.best_prob[cur_idx:]
					self.best_vf = self.best_vf[:cur_idx] + [cur_vf] + self.best_vf[cur_idx:]

					# Keep only the best `beam_size` finished programs.
					if len(self.best_prob) > self.beam_size:
						self.best_predicted_progs = self.best_predicted_progs[:self.beam_size]
						self.best_prob = self.best_prob[:self.beam_size]
						self.best_vf = self.best_vf[:self.beam_size]
					# Raise the pruning threshold: to this program's probability when it is
					# consistent, and to the weakest stored probability once the list is full.
					if cur_vf == 1.0:
						self.min_prob_limit = max(self.min_prob_limit, cur_cum_prob)
					if len(self.best_prob) >= self.beam_size:
						self.min_prob_limit = max(self.min_prob_limit, self.best_prob[-1])
					continue

				# Non-terminal token: execute it on every I/O example.
				crash = False
				input_grids = []
				karel_worlds = []
				cur_syntax_checker_state = None

				for io_idx in range(self.io_count):
					cur_karel_world, cur_syntax_checker_state = self.simulator.execute(init_karel_worlds[idx][io_idx], syntax_checker_state[idx], prog_token)
					if cur_karel_world.crash:
						crash = True
					input_grids.append(cur_karel_world.grid)
					karel_worlds.append(cur_karel_world)

				# A candidate that crashes on any example is dropped.
				if crash:
					continue
					
				input_grids = np.array(input_grids)
				input_grids = data_utils.np_to_tensor(input_grids, 'float', cuda_flag=self.cuda_flag, volatile_flag=True)
				input_grids = input_grids.unsqueeze(0)

				current_prog_token = np.array([prog_token])
				current_prog_token = data_utils.np_to_tensor(current_prog_token, 'int', cuda_flag=self.cuda_flag, volatile_flag=True)
				# Baseline mode conditions on the initial grids; otherwise on the grids just produced by execution.
				if self.baseline:
					decoder_output, decoder_state = self.model(init_input_grids[idx], output_grids[idx], current_prog_token, init_decoder_state[idx])
				else:
					decoder_output, decoder_state = self.model(input_grids, output_grids[idx], current_prog_token, init_decoder_state[idx])
				# NOTE(review): exp() implies the decoder output is log-probabilities - confirm in ProgDecoder.
				next_prob = np.exp(decoder_output[0].data.cpu().numpy())
				next_prob = next_prob + self.sample_bias

				# Scan from the back past every entry with a lower cumulative probability
				# to find the insertion index that keeps the frontier in descending order.
				cur_idx = len(new_cum_prob) - 1
				while cur_idx >= 0:
					if new_cum_prob[cur_idx] < cur_cum_prob:
						cur_idx -= 1
					else:
						break
				cur_idx += 1

				new_syntax_checker_state = new_syntax_checker_state[:cur_idx] + [cur_syntax_checker_state] + new_syntax_checker_state[cur_idx:]
				new_init_prob = new_init_prob[:cur_idx] + [next_prob] + new_init_prob[cur_idx:]
				new_init_input_grids = new_init_input_grids[:cur_idx] + [init_input_grids[idx]] + new_init_input_grids[cur_idx:]
				new_init_karel_worlds = new_init_karel_worlds[:cur_idx] + [karel_worlds] + new_init_karel_worlds[cur_idx:]
				new_init_decoder_state = new_init_decoder_state[:cur_idx] + [decoder_state] + new_init_decoder_state[cur_idx:]
				new_output_grids = new_output_grids[:cur_idx] + [output_grids[idx]] + new_output_grids[cur_idx:]
				new_output_karel_worlds = new_output_karel_worlds[:cur_idx] + [output_karel_worlds[idx]] + new_output_karel_worlds[cur_idx:]
				new_cum_prob = new_cum_prob[:cur_idx] + [cur_cum_prob] + new_cum_prob[cur_idx:]
				new_prog_prefix = new_prog_prefix[:cur_idx] + [prog_prefix[idx] + [prog_token]] + new_prog_prefix[cur_idx:]

				# Trim the frontier back to `internal_beam_size` entries.
				if len(new_cum_prob) > self.internal_beam_size:
					new_syntax_checker_state = new_syntax_checker_state[:self.internal_beam_size]
					new_init_prob = new_init_prob[:self.internal_beam_size]
					new_init_input_grids = new_init_input_grids[:self.internal_beam_size]
					new_init_karel_worlds = new_init_karel_worlds[:self.internal_beam_size]
					new_init_decoder_state = new_init_decoder_state[:self.internal_beam_size]
					new_output_grids = new_output_grids[:self.internal_beam_size]
					new_output_karel_worlds = new_output_karel_worlds[:self.internal_beam_size]
					new_cum_prob = new_cum_prob[:self.internal_beam_size]
					new_prog_prefix = new_prog_prefix[:self.internal_beam_size]

		# Recurse on the new frontier with one less level of depth (gt_prog is not passed on).
		self.search(new_syntax_checker_state, new_init_prob, new_init_input_grids, new_init_karel_worlds, new_init_decoder_state, new_output_grids, new_output_karel_worlds, current_depth - 1, new_cum_prob, new_prog_prefix)

	def run_prog_gen(self, train_input_grids, train_output_grids, val_input_grids, val_output_grids, progs, init_feed_previous=False):
		"""Decode programs step by step, or run beam search when in search mode.

		Args:
			train_input_grids: input-grid tensor; dimension 0 is the batch and each
				sample is indexed by I/O example up to ``self.io_count``.
			train_output_grids: matching output-grid tensor.
			val_input_grids: accepted for interface compatibility; not used here.
			val_output_grids: accepted for interface compatibility; not used here.
			progs: token ids indexed as ``progs[:, t]``; column 0 is the first decoder input.
			init_feed_previous: when true, the decoder's own output drives the next
				step; when false, the tokens in ``progs`` are used (teacher forcing).

		Returns:
			``(predictions_logit, predictions_prog)`` - per-step decoder outputs and
			per-step token arrays - in the normal case. When ``init_feed_previous``
			is true and ``self.eval_mode == 'search'`` it instead returns
			``(predicted_progs, satisfied)``: for each sample the list of programs
			found by ``search`` and whether the best one is I/O-consistent.

		NOTE(review): the teacher-forcing path calls ``.data.cpu()`` on ``progs``, so it
		needs ``progs`` to be a tensor; ``preprocess_data`` leaves it as a NumPy array
		when ``eval_mode == 'search'`` - confirm that combination is never used.
		"""
		volatile_flag = init_feed_previous
		feed_previous = init_feed_previous
		batch_size = train_input_grids.size()[0]

		# Zero initial (hidden, cell) decoder state.
		init_h = Variable(torch.zeros(self.num_RNN_layers, batch_size, self.hidden_size))
		init_c = Variable(torch.zeros(self.num_RNN_layers, batch_size, self.hidden_size))
		if self.cuda_flag:
			init_h = init_h.cuda()
			init_c = init_c.cuda()
		decoder_state = (init_h, init_c)

		predictions_logit = []
		predictions_prog = []
		prog_tokens = progs[:, 0]
		if self.eval_mode == 'search':
			prog_tokens = data_utils.np_to_tensor(prog_tokens, "int", self.cuda_flag)
		input_grids = train_input_grids
		output_grids = train_output_grids
		# Baseline decoding always conditions on the untouched initial grids, so
		# keep a single reference instead of accumulating the grids of every step.
		initial_input_grids = train_input_grids

		syntax_checker_states = [
			self.simulator.syntax_checker.get_init_syntax_state()
			for _ in range(batch_size)
		]

		def build_karel_worlds(grids):
			# One execution state per (sample, I/O example).
			np_grids = grids.data.cpu().numpy()
			return [
				[karel_simulator.KarelExecutionState(np_grids[idx][io_idx]) for io_idx in range(self.io_count)]
				for idx in range(batch_size)
			]

		karel_worlds = build_karel_worlds(input_grids)
		output_karel_worlds = build_karel_worlds(output_grids)

		for time_step in range(1, self.max_prog_len):
			model_input_grids = initial_input_grids if self.baseline else input_grids
			decoder_output, decoder_state = self.model(model_input_grids, output_grids, prog_tokens, decoder_state)
			predictions_logit.append(decoder_output)
			if feed_previous:
				# exp() turns the decoder output into probabilities; sample_bias is then added.
				probs = np.exp(decoder_output.data.cpu().numpy())
				probs = probs + self.sample_bias
			else:
				# Teacher forcing: hand the ground-truth tokens to the simulator.
				probs = progs[:, time_step].data.cpu().numpy()

			if init_feed_previous and self.eval_mode == 'search':
				# Search mode runs a full beam search per sample from the first
				# decoder step and returns without further stepping.
				predicted_progs = []
				satisfied = []
				for idx in range(batch_size):
					self.best_predicted_progs = []
					self.best_prob = []
					self.best_vf = []
					self.min_prob_limit = 0.0
					# NOTE(review): unsqueeze(0) on the [:, idx, :] slice matches a
					# (layers, batch=1, hidden) state only when num_RNN_layers == 1 - confirm.
					self.search([syntax_checker_states[idx]], [probs[idx]], [input_grids[idx].unsqueeze(0)], [karel_worlds[idx]], [(decoder_state[0][:, idx, :].unsqueeze(0), decoder_state[1][:, idx, :].unsqueeze(0))], [output_grids[idx].unsqueeze(0)], [output_karel_worlds[idx]], self.search_depth, [1.0], [[]], progs[idx][1:])
					predicted_progs.append(list(self.best_predicted_progs))
					# search() may finish without any complete program; report None
					# instead of indexing into an empty list.
					top_vf = self.best_vf[0] if self.best_vf else None
					satisfied.append(top_vf == 1.0)
					print(idx, top_vf)
				return predicted_progs, satisfied

			prog_tokens, input_grids, karel_worlds, syntax_checker_states = self.simulator.simulate(syntax_checker_states, probs, karel_worlds, feed_previous=feed_previous, volatile_flag=volatile_flag)
			if feed_previous is False:
				prog_tokens = progs[:, time_step]
			predictions_prog.append(prog_tokens.data.cpu().numpy())

		return predictions_logit, predictions_prog

	def preprocess_data(self, batch_data):
		"""Turn a batch of samples into grid tensors and padded program ids.

		Each item of ``batch_data`` is a 5-tuple
		``(train_ios, val_io, prog, train_traces, val_trace)``; the two trace
		fields are not used. Only the first pair of ``val_io`` is read.

		Returns:
			``(train_input_grids, train_output_grids, val_input_grids,
			val_output_grids, progs)``. The grids are converted with
			``data_utils.np_to_tensor``; ``progs`` is converted too unless
			``self.eval_mode == 'search'``, in which case it stays a NumPy array.
		"""
		# Programs are laid out as GO + tokens + PAD up to max_prog_len entries.
		pad_target = self.max_prog_len - 1
		padded_progs = []
		train_inputs, train_outputs = [], []
		val_inputs, val_outputs = [], []

		for train_ios, val_io, prog, _train_traces, _val_trace in batch_data:
			padding = [data_utils.PAD_ID] * (pad_target - len(prog))
			padded_prog = [data_utils.GO_ID] + prog + padding

			example_inputs, example_outputs = [], []
			for example_in, example_out in train_ios:
				example_inputs.append(example_in)
				example_outputs.append(example_out)
			train_inputs.append(example_inputs)
			train_outputs.append(example_outputs)

			held_out_in, held_out_out = val_io[0]
			val_inputs.append(held_out_in)
			val_outputs.append(held_out_out)
			padded_progs.append(padded_prog)

		train_input_grids = np.array(train_inputs)
		train_output_grids = np.array(train_outputs)
		val_input_grids = np.array(val_inputs)
		val_output_grids = np.array(val_outputs)
		progs = np.array(padded_progs)

		if self.eval_mode != 'search':
			progs = data_utils.np_to_tensor(progs, "int", self.cuda_flag)
		train_input_grids = data_utils.np_to_tensor(train_input_grids, "float", self.cuda_flag)
		train_output_grids = data_utils.np_to_tensor(train_output_grids, "float", self.cuda_flag)
		val_input_grids = data_utils.np_to_tensor(val_input_grids, "float", self.cuda_flag)
		val_output_grids = data_utils.np_to_tensor(val_output_grids, "float", self.cuda_flag)
		return train_input_grids, train_output_grids, val_input_grids, val_output_grids, progs

	def compute_loss(self, batch_data, predictions_logit):
		train_input_grids, train_output_grids, val_input_grids, val_output_grids, progs = self.preprocess_data(batch_data)
		max_prog_len = 0
		for data_item in batch_data:
			train_ios, val_io, prog, train_traces, val_trace = data_item
			max_prog_len = max(max_prog_len, len(prog))
		total_loss = None
		for time_step in range(max_prog_len):
			pred = predictions_logit[time_step]
			target = progs[:, time_step + 1]
			loss = self.model.compute_loss(pred, target)
			if total_loss is None:
				total_loss = loss
			else:
				total_loss += loss
		total_loss /= max_prog_len
		return total_loss


def create_model(args, prog_vocab, prog_vocab_list):
	"""Build the model named by ``args.model_type`` and wrap it in a Supervisor.

	The model is moved to the GPU when its ``cuda_flag`` is set. If
	``args.load_model`` is truthy the supervisor loads those pretrained
	weights; otherwise the parameters are initialised with ``args.param_init``.

	Returns:
		The ``supervisor.Supervisor`` wrapping the new model.

	Raises:
		ValueError: if ``args.model_type`` is not ``'karel'``.
	"""
	if args.model_type != 'karel':
		# Previously this only printed a message and then failed on the unbound `model` below.
		raise ValueError('model type undefined: {}'.format(args.model_type))

	model = KarelModel(args, prog_vocab, prog_vocab_list)
	if model.cuda_flag:
		model = model.cuda()
	model_supervisor = supervisor.Supervisor(model, args, prog_vocab, prog_vocab_list)
	if args.load_model:
		model_supervisor.load_pretrained(args.load_model)
	else:
		print('Created model with fresh parameters.')
		model_supervisor.model.init_weights(args.param_init)
	return model_supervisor
