import numpy as np
import operator
import random
import time
import copy
from ..data_utils import data_utils

# Parser states used by KarelSyntaxChecker.update_syntax_state.
(
	STATE_MANDATORY_NEXT,       # exactly one specific token must come next
	STATE_ACT_NEXT,             # inside a statement list (actions / flow / close paren)
	STATE_CONDITIONAL_NEXT,     # a conditional or `not` must follow `c(`
	STATE_REPEAT_NEXT,          # an `R=<k>` count must follow `REPEAT`
	STATE_POSTCOND_OPEN_PAREN,  # `i(` or `w(` must follow a closed condition
) = range(5)

# Indicator array of size 16 x height x width (height, width <= 18)
# 1st axis:
#   0: Hero facing North
#   1: Hero facing East
#   2: Hero facing South
#   3: Hero facing West
#   4: Internal walls
#   5: Surrounding walls
#   6: 1 marker
#   7: 2 markers
#   8: 3 markers
#   9: 4 markers
#   10: 5 markers
#   11: 6 markers
#   12: 7 markers
#   13: 8 markers
#   14: 9 markers
#   15: 10 markers
# Borders of array have the surrounding walls bit set.

class KarelSyntaxVocab(object):
	"""Ids of the structural tokens of the Karel program grammar.

	Every ``<name>_token`` attribute holds ``prog_vocab[<symbol>]`` for one
	keyword or bracket.  A missing symbol raises ``KeyError``.
	``prog_vocab_list`` is accepted for interface compatibility and unused.
	"""

	# (attribute name, vocabulary symbol), resolved in this order.
	_TOKEN_TABLE = (
		('start_token', '<s>'),
		('def_token', 'DEF'),
		('run_token', 'run'),
		('m_open_token', 'm('),
		('m_close_token', 'm)'),
		('else_token', 'ELSE'),
		('e_open_token', 'e('),
		('e_close_token', 'e)'),
		('c_open_token', 'c('),
		('c_close_token', 'c)'),
		('i_open_token', 'i('),
		('i_close_token', 'i)'),
		('while_token', 'WHILE'),
		('w_open_token', 'w('),
		('w_close_token', 'w)'),
		('repeat_token', 'REPEAT'),
		('r_open_token', 'r('),
		('r_close_token', 'r)'),
		('not_token', 'not'),
	)

	def __init__(self, prog_vocab, prog_vocab_list):
		for attr_name, symbol in self._TOKEN_TABLE:
			setattr(self, attr_name, prog_vocab[symbol])

class KarelExecutionState(object):
	"""Mutable Karel world plus bookkeeping for token-by-token execution.

	The world is decoded from ``grid``, an indicator array whose first axis
	selects a channel: 0-3 hero facing North/East/South/West, 4 internal
	walls, 5 surrounding walls, and ``marker_delta + k`` (k = 1..max_markers)
	a cell holding exactly k markers.  Playable cells are ``1..height`` by
	``1..width``.

	Moving North increases ``heroX`` and moving East increases ``heroY``.
	Illegal moves and marker operations are silently ignored; ``crash`` is
	initialised to False and never set by this class.

	Raises:
		ValueError: if ``grid`` contains no hero, several heroes, or a hero
			with several directions.
	"""

	# Direction ids follow ``dir_vocab``: 0 North, 1 East, 2 South, 3 West.
	# (dx, dy) added to (heroX, heroY) when stepping in each direction.
	_STEP = {0: (1, 0), 1: (0, 1), 2: (-1, 0), 3: (0, -1)}
	# Direction reached by a left (counter-clockwise) / right (clockwise) turn.
	_LEFT_OF = {0: 3, 3: 2, 2: 1, 1: 0}
	_RIGHT_OF = {0: 1, 1: 2, 2: 3, 3: 0}

	def __init__(self, grid):
		self.ticks = 0
		self.max_ticks = 200
		self.grid = grid.copy()
		self.grid = self.grid.astype(int)
		# Incremental-execution bookkeeping maintained by KarelSimulator.execute.
		self.prog_buffer = []
		self.repeat_stack = []
		self.while_stack = []
		self.while_conditional_stack = []
		self.if_stack = []
		self.if_conditional_stack = []
		self.pending_conditional = 0
		self.executed_pos = 0
		xs, ys = np.where(grid[5])
		# The last surrounding-wall row/column sits just past the playable area.
		self.height, self.width = xs.max() - 1, ys.max() - 1
		hero_pos = list(zip(*np.where(np.any(grid[:4], axis=0))))
		if len(hero_pos) == 0:
			# Previously surfaced as a bare IndexError on hero_pos[0].
			raise ValueError('Invalid state: no hero')
		if len(hero_pos) > 1:
			raise ValueError('Invalid state: too many heroes')
		self.heroX = hero_pos[0][0]
		self.heroY = hero_pos[0][1]
		direction, = np.where(np.any(grid[:4], axis=(1,2)))
		if len(direction) > 1:
			raise ValueError('Invalid state: too many hero directions')
		self.hero_dir = direction[0]
		# View into self.grid (not a copy).
		self.blocks = self.grid[4]
		self.crash = False
		self.dir_vocab = {'North': 0, 'East': 1, 'South': 2, 'West': 3}
		self.max_markers = 10
		self.marker_delta = 5
		# Per-cell marker count, decoded from the one-hot marker channels.
		self.markers = np.zeros((self.height + 1, self.width + 1), dtype=int)
		for i in range(1, self.max_markers + 1):
			self.markers += self.grid[i + self.marker_delta, :self.height + 1, :self.width + 1] * i

	def update_prog_buffer(self, prog_token):
		"""Append ``prog_token`` to the buffered, not-yet-executed program."""
		self.prog_buffer.append(prog_token)

	def clear_prog_buffer(self):
		"""Drop the buffered program and all control-flow bookkeeping tied to it."""
		self.prog_buffer = []
		self.repeat_stack = []
		self.while_stack = []
		self.while_conditional_stack = []
		self.if_stack = []
		self.if_conditional_stack = []
		self.executed_pos = 0

	def __eq__(self, other):
		"""States are equal when hero position, direction and markers all match."""
		if self.heroX != other.heroX:
			return False
		if self.heroY != other.heroY:
			return False
		if self.hero_dir != other.hero_dir:
			return False
		return bool(np.array_equal(self.markers, other.markers))

	def __ne__(self, other):
		return not (self == other)

	def dis(self, other):
		"""Manhattan hero distance + direction mismatch + total marker difference."""
		res = abs(self.heroX - other.heroX) + abs(self.heroY - other.heroY) + (self.hero_dir != other.hero_dir)
		res += np.sum(np.abs(self.markers - other.markers))
		return res

	def isClear(self, x, y):
		"""True when (x, y) is inside the playable area and holds no internal wall."""
		if x <= 0 or y <= 0 or x > self.height or y > self.width:
			return False
		if self.blocks[x][y] > 0:
			return False
		return True

	def stepsExceeded(self):
		return self.ticks > self.max_ticks

	def _neighbour(self, direction):
		"""Cell next to the hero in ``direction``, or None for an unknown id."""
		step = self._STEP.get(direction)
		if step is None:
			return None
		return self.heroX + step[0], self.heroY + step[1]

	def _clear_towards(self, direction):
		"""isClear() of the neighbouring cell; None when ``direction`` is unknown."""
		target = self._neighbour(direction)
		if target is None:
			return None
		return self.isClear(*target)

	def _turn(self, table):
		"""Rotate the hero using ``table`` and keep the direction planes in sync."""
		if self.crash:
			return
		new_hero_dir = table.get(self.hero_dir, self.hero_dir)
		self.grid[self.hero_dir, self.heroX, self.heroY] = 0
		self.hero_dir = new_hero_dir
		self.grid[self.hero_dir, self.heroX, self.heroY] = 1

	def move(self):
		"""Step one cell forward; a blocked or out-of-bounds step is a no-op."""
		if self.crash:
			return
		target = self._neighbour(self.hero_dir)
		if target is None:
			target = (self.heroX, self.heroY)
		new_heroX, new_heroY = target
		if not self.isClear(new_heroX, new_heroY):
			return
		self.grid[self.hero_dir, self.heroX, self.heroY] = 0
		self.heroX = new_heroX
		self.heroY = new_heroY
		self.grid[self.hero_dir, self.heroX, self.heroY] = 1

	def turnLeft(self):
		self._turn(self._LEFT_OF)

	def turnRight(self):
		self._turn(self._RIGHT_OF)

	def putMarker(self):
		"""Add a marker to the hero's cell unless it already holds max_markers."""
		if self.crash:
			return
		if self.markers[self.heroX][self.heroY] + 1 > self.max_markers:
			return
		self.markers[self.heroX][self.heroY] += 1
		# Move the one-hot marker bit from the old count plane to the new one.
		if self.markers[self.heroX][self.heroY] > 1:
			self.grid[self.markers[self.heroX][self.heroY] - 1 + self.marker_delta, self.heroX, self.heroY] = 0
		self.grid[self.markers[self.heroX][self.heroY] + self.marker_delta, self.heroX, self.heroY] = 1

	def pickMarker(self):
		"""Remove a marker from the hero's cell; a no-op on an empty cell."""
		if self.crash:
			return
		if self.markers[self.heroX][self.heroY] == 0:
			return
		self.markers[self.heroX][self.heroY] -= 1
		self.grid[self.markers[self.heroX][self.heroY] + 1 + self.marker_delta, self.heroX, self.heroY] = 0
		if self.markers[self.heroX][self.heroY] > 0:
			self.grid[self.markers[self.heroX][self.heroY] + self.marker_delta, self.heroX, self.heroY] = 1

	def markersPresent(self):
		if self.crash:
			return
		return self.markers[self.heroX][self.heroY] > 0

	def noMarkersPresent(self):
		if self.crash:
			return
		return self.markers[self.heroX][self.heroY] == 0

	def leftIsClear(self):
		if self.crash:
			return
		return self._clear_towards(self._LEFT_OF.get(self.hero_dir))

	def rightIsClear(self):
		if self.crash:
			return
		return self._clear_towards(self._RIGHT_OF.get(self.hero_dir))

	def frontIsClear(self):
		if self.crash:
			return
		return self._clear_towards(self.hero_dir)

	def executeAction(self, actionName):
		"""Run the action method named ``actionName`` and charge one tick."""
		action_func = getattr(self, actionName)
		action_func()
		self.ticks += 1

	def evalCond(self, condName):
		"""Charge one tick and return the condition method named ``condName``."""
		cond_func = getattr(self, condName)
		self.ticks += 1
		return cond_func()

class KarelSyntaxCheckerState(object):
	"""Parser state threaded through KarelSyntaxChecker.update_syntax_state.

	Attributes:
		state_type: current parser state (one of the STATE_* constants).
		mask: private (deep-copied) mask of tokens allowed next.
		need_else_stack: per open IF/IFELSE, whether an ELSE must follow.
		to_close_stack: close-paren tokens still expected, innermost last.
		conditional_depth: nesting depth of open ``c(`` groups.
		tot_conditinonal: counter initialised to 0 (name kept as-is).
		next_act_block: open-paren token expected once the condition closes
			(-1 until set).
	"""

	def __init__(self, state_type, mask):
		self.need_else_stack = []
		self.to_close_stack = []
		self.conditional_depth = 0
		self.tot_conditinonal = 0
		self.next_act_block = -1
		self.update(state_type, mask)

	def update(self, state_type, mask):
		"""Switch to ``state_type`` holding a private copy of ``mask``."""
		self.state_type, self.mask = state_type, copy.deepcopy(mask)

	def push_closeparen_to_stack(self, close_paren):
		"""Record ``close_paren`` as the innermost expected closing token."""
		self.to_close_stack += [close_paren]

	def pop_closeparen(self):
		"""Remove and return the innermost expected closing token."""
		return self.to_close_stack.pop(-1)

	def push_conditional(self):
		self.conditional_depth = self.conditional_depth + 1

	def pop_conditional(self):
		self.conditional_depth = self.conditional_depth - 1

	def push_needelse_stack(self, need_else):
		self.need_else_stack += [need_else]

	def pop_needelse_stack(self):
		return self.need_else_stack.pop(-1)

	def set_next_act_block(self, token):
		self.next_act_block = token

class KarelSyntaxChecker(object):
	"""Incremental grammar checker for the Karel program DSL.

	``update_syntax_state`` consumes one program token at a time.  It keeps
	in ``state.mask`` a vector over the vocabulary with 1.0 at the tokens
	that may legally come next.  The grammar encoded by the transitions is:

		DEF run m( <stmts> m)
		IF c( <cond> c) i( <stmts> i)
		IFELSE c( <cond> c) i( <stmts> i) ELSE e( <stmts> e)
		WHILE c( <cond> c) w( <stmts> w)
		REPEAT R=<k> r( <stmts> r)
		<cond> := <conditional> | not c( <cond> c)

	Most unexpected tokens raise NotImplementedError.
	"""
	def __init__(self, args, prog_vocab, prog_vocab_list):
		"""Precompute token-class lookups and the fixed next-token masks.

		Args:
			args: options object.  Only ``args.prog_vocab_size`` is read, and
				it is used as the length of every mask.  It is presumably
				equal to ``len(prog_vocab)`` (TODO confirm).
			prog_vocab: mapping from token string to token id.
			prog_vocab_list: token strings indexed by token id.  It is
				scanned for the ``R=<k>`` repeat-count tokens.
		"""
		self.prog_vocab = prog_vocab
		self.prog_vocab_list = prog_vocab_list
		self.prog_vocab_size = args.prog_vocab_size

		# Token-id sets for each token class.  Dicts mapping id -> 1 are used
		# as hash sets.
		self.open_paren_token = ['m(', 'c(', 'r(', 'w(', 'i(', 'e(']
		self.open_paren_hash = {}
		for tok in self.open_paren_token:
			self.open_paren_hash[prog_vocab[tok]] = 1

		self.close_paren_token = ['m)', 'c)', 'r)','w)', 'i)', 'e)']
		self.close_paren_hash = {}
		for tok in self.close_paren_token:
			self.close_paren_hash[prog_vocab[tok]] = 1

		# IF has no ELSE part, IFELSE requires one.
		self.if_need_else = {prog_vocab['IF']: False, prog_vocab['IFELSE']: True}

		# Open-paren id -> matching close-paren id (lists are index-aligned).
		self.op2cl = {}
		for op, cl in zip(self.open_paren_token, self.close_paren_token):
			self.op2cl[prog_vocab[op]] = prog_vocab[cl]

		self.flow_leads = ['REPEAT', 'WHILE', 'IF', 'IFELSE']
		self.flow_lead_hash = {}
		for tok in self.flow_leads:
			self.flow_lead_hash[prog_vocab[tok]] = 1

		# Flow keywords followed by a c( ... c) condition.
		self.conditional_flow_leads = ['WHILE', 'IF', 'IFELSE']
		self.conditional_flow_lead_hash = {}
		for tok in self.conditional_flow_leads:
			self.conditional_flow_lead_hash[prog_vocab[tok]] = 1

		self.actions = [
		'move',
		'turnLeft',
		'turnRight',
		'pickMarker',
		'putMarker'
		]
		self.action_hash = {}
		for act in self.actions:
			self.action_hash[prog_vocab[act]] = 1

		self.conditionals = [
		'markersPresent',
		'noMarkersPresent',
		'leftIsClear',
		'rightIsClear',
		'frontIsClear'
		]
		self.conditional_hash = {}
		for cond in self.conditionals:
			self.conditional_hash[prog_vocab[cond]] = 1

		# Block openers that follow a closed condition: i( for IF/IFELSE, w( for WHILE.
		self.postcond_open_paren = ['i(', 'w(']
		self.postcond_open_paren_hash = {}
		for tok in self.postcond_open_paren:
			self.postcond_open_paren_hash[prog_vocab[tok]] = 1

		# Repeat-count tokens are every vocabulary entry spelled "R=<k>".
		self.repeat_token_hash = {}
		for idx, tok in enumerate(self.prog_vocab_list):
			if tok.startswith('R='):
				self.repeat_token_hash[idx] = 1

		self.syntax_vocab = KarelSyntaxVocab(prog_vocab, prog_vocab_list)

		# One-hot masks allowing exactly one token, keyed by that token's id.
		self.possible_mandatories = ['<s>', 'DEF', 'run', 'c)', 'ELSE', data_utils._PAD] + self.open_paren_token
		self.mandatories_masks = {}
		for tok in self.possible_mandatories:
			mask = np.zeros(self.prog_vocab_size)
			mask[prog_vocab[tok]] = 1.0
			self.mandatories_masks[prog_vocab[tok]] = mask

		# Statement-list masks: any action or flow keyword, plus the single
		# close paren that the mask is keyed by.
		self.act_masks = {}
		act_mask = np.zeros(self.prog_vocab_size)
		for act in self.action_hash:
			act_mask[act] = 1.0
		for tok in self.flow_lead_hash:
			act_mask[tok] = 1.0
		for tok in self.close_paren_hash:
			current_act_mask = act_mask.copy()
			current_act_mask[tok] = 1.0
			self.act_masks[tok] = current_act_mask

		self.repeat_mask = np.zeros(self.prog_vocab_size)
		for tok in self.repeat_token_hash:
			self.repeat_mask[tok] = 1.0

		# Inside c( ... ): a conditional or a nested `not`.
		self.conditional_mask = np.zeros(self.prog_vocab_size)
		for tok in self.conditional_hash:
			self.conditional_mask[tok] = 1.0
		self.conditional_mask[self.syntax_vocab.not_token] = 1.0

		self.postcond_open_paren_masks = {}
		for tok in self.postcond_open_paren_hash:
			mask = np.zeros(self.prog_vocab_size)
			mask[tok] = 1.0
			self.postcond_open_paren_masks[tok] = mask

	def get_init_syntax_state(self):
		"""Return a fresh parser state that expects DEF next."""
		return KarelSyntaxCheckerState(STATE_MANDATORY_NEXT, self.mandatories_masks[self.syntax_vocab.def_token])

	def update_syntax_state(self, state, prog_token):
		"""Advance ``state`` by one ``prog_token``.

		``state`` is mutated in place and also returned.

		Raises:
			NotImplementedError: for a token that is not expected in the
				current state.  STATE_REPEAT_NEXT does not validate its token.
		"""
		# Keep the stack of expected close parens in sync with the token stream.
		# NOTE(review): the mask edits in these two blocks appear to be
		# overwritten by the state.update(...) calls below on every
		# non-raising path; only the push/pop of to_close_stack seems to matter.
		if prog_token in self.open_paren_hash:
			close_token = self.op2cl[prog_token]
			state.push_closeparen_to_stack(close_token)
			for tok in self.close_paren_hash:
				if tok == close_token:
					state.mask[tok] = 1.0
				else:
					state.mask[tok] = 0.0
		if prog_token in self.close_paren_hash:
			close_token = state.pop_closeparen()
			for tok in self.close_paren_hash:
				state.mask[tok] = 0.0
			if len(state.to_close_stack) > 0:
				state.mask[state.to_close_stack[-1]] = 1.0

		if state.state_type == STATE_MANDATORY_NEXT:
			# Fixed prefix: <s> DEF run m(
			if prog_token == self.syntax_vocab.start_token:
				state.update(STATE_MANDATORY_NEXT, self.mandatories_masks[self.syntax_vocab.def_token])
			elif prog_token == self.syntax_vocab.def_token:
				state.update(STATE_MANDATORY_NEXT, self.mandatories_masks[self.syntax_vocab.run_token])
			elif prog_token == self.syntax_vocab.run_token:
				state.update(STATE_MANDATORY_NEXT, self.mandatories_masks[self.syntax_vocab.m_open_token])
			elif prog_token == self.syntax_vocab.else_token:
				state.update(STATE_MANDATORY_NEXT, self.mandatories_masks[self.syntax_vocab.e_open_token])
			elif prog_token in self.open_paren_hash:
				if prog_token == self.syntax_vocab.c_open_token:
					# Entering a (possibly nested) condition.
					state.update(STATE_CONDITIONAL_NEXT, self.conditional_mask)
					state.push_conditional()
				else:
					# Entering a statement list closed by the matching paren.
					close_token = self.op2cl[prog_token]
					state.update(STATE_ACT_NEXT, self.act_masks[close_token])
			elif prog_token == self.syntax_vocab.c_close_token:
				state.pop_conditional()
				if state.conditional_depth == 0:
					# Outermost condition closed: the block opener recorded for
					# the pending IF/IFELSE/WHILE must follow.
					state.update(STATE_POSTCOND_OPEN_PAREN, self.postcond_open_paren_masks[state.next_act_block])
				else:
					# Still inside `not c( ... )`: another c) must follow.
					state.update(STATE_MANDATORY_NEXT, self.mandatories_masks[self.syntax_vocab.c_close_token])
			elif prog_token == data_utils.PAD_ID:
				pass
			else:
				raise NotImplementedError
		elif state.state_type == STATE_ACT_NEXT:
			if prog_token in self.conditional_flow_lead_hash:
				state.update(STATE_MANDATORY_NEXT, self.mandatories_masks[self.syntax_vocab.c_open_token])
				# Remember which block opener follows the condition.
				if prog_token in self.if_need_else:
					state.push_needelse_stack(self.if_need_else[prog_token])
					state.set_next_act_block(self.syntax_vocab.i_open_token)
				elif prog_token == self.syntax_vocab.while_token:
					state.set_next_act_block(self.syntax_vocab.w_open_token)
			elif prog_token == self.syntax_vocab.repeat_token:
				state.update(STATE_REPEAT_NEXT, self.repeat_mask)
			elif prog_token in self.action_hash:
				pass
			elif prog_token in self.close_paren_hash:
				if prog_token == self.syntax_vocab.i_close_token:
					need_else = state.pop_needelse_stack()
					if need_else:
						state.update(STATE_MANDATORY_NEXT, self.mandatories_masks[self.syntax_vocab.else_token])
					else:
						state.update(STATE_ACT_NEXT, self.act_masks[state.to_close_stack[-1]])
				elif prog_token == self.syntax_vocab.m_close_token:
					# End of program: only padding may follow.
					# NOTE(review): mandatories_masks is keyed by
					# prog_vocab[data_utils._PAD]; this assumes that equals
					# data_utils.PAD_ID -- TODO confirm.
					state.update(STATE_MANDATORY_NEXT, self.mandatories_masks[data_utils.PAD_ID])
				else:
					state.update(STATE_ACT_NEXT, self.act_masks[state.to_close_stack[-1]])
			else:
				raise NotImplementedError
		elif state.state_type == STATE_REPEAT_NEXT:
			# The R=<k> count token is consumed without validation.
			state.update(STATE_MANDATORY_NEXT, self.mandatories_masks[self.syntax_vocab.r_open_token])
		elif state.state_type == STATE_CONDITIONAL_NEXT:
			if prog_token == self.syntax_vocab.not_token:
				state.update(STATE_MANDATORY_NEXT, self.mandatories_masks[self.syntax_vocab.c_open_token])
			elif prog_token in self.conditional_hash:
				state.update(STATE_MANDATORY_NEXT, self.mandatories_masks[self.syntax_vocab.c_close_token])
			else:
				raise NotImplementedError
		elif state.state_type == STATE_POSTCOND_OPEN_PAREN:
			state.update(STATE_ACT_NEXT, self.act_masks[state.to_close_stack[-1]])
		else:
			raise NotImplementedError

		return state


class KarelSimulator(object):
	"""Execute Karel programs incrementally while they are being decoded.

	``execute`` applies one program token to copies of a world state and a
	syntax-checker state.  ``simulate`` does this for a whole batch, and can
	optionally choose the next token from masked model predictions.

	Complete control-flow constructs are run by the recursive
	``execute_block`` interpreter.  It keeps its frames on
	``self.block_stack`` (token arrays) and ``self.block_pos_stack`` (the
	read position inside each array).
	"""

	def __init__(self, args, prog_vocab, prog_vocab_list):
		"""Build the simulator.

		Args:
			args: options object.  This class reads ``io_count``, ``cuda``,
				``eval_mode`` and ``sample_bias``.  KarelSyntaxChecker
				additionally reads ``prog_vocab_size``.
			prog_vocab: mapping from token string to token id.
			prog_vocab_list: token strings indexed by token id.
		"""
		self.syntax_checker = KarelSyntaxChecker(args, prog_vocab, prog_vocab_list)
		self.prog_vocab = prog_vocab
		self.prog_vocab_size = len(prog_vocab)
		self.prog_vocab_list = prog_vocab_list
		self.io_count = args.io_count
		self.actions = [
		'move',
		'turnLeft',
		'turnRight',
		'pickMarker',
		'putMarker'
		]
		self.action_hash = {}
		for act in self.actions:
			self.action_hash[prog_vocab[act]] = 1

		self.conditionals = [
		'markersPresent',
		'noMarkersPresent',
		'leftIsClear',
		'rightIsClear',
		'frontIsClear'
		]
		self.conditional_hash = {}
		for cond in self.conditionals:
			self.conditional_hash[prog_vocab[cond]] = 1
		self.cuda_flag = args.cuda
		self.eval_mode = args.eval_mode
		# Interpreter frames used by execute_block.
		self.block_stack = []
		self.block_pos_stack = []
		self.sample_bias = args.sample_bias

		self.prog_vocab_reverse = dict(zip(prog_vocab.values(), prog_vocab.keys()))

	def eval_cond_seq(self, karel_world, prog_buffer):
		"""Evaluate a condition token sequence against ``karel_world``.

		The sequence is a single conditional token or ``not c( <cond> c)``.
		Each evaluated conditional costs one tick (via ``evalCond``).
		"""
		cond = False
		if prog_buffer[0] == self.syntax_checker.syntax_vocab.not_token:
			# Strip `not c(` and the trailing `c)`.
			cond = self.eval_cond_seq(karel_world, prog_buffer[2:-1])
			cond = not cond
		else:
			cond = karel_world.evalCond(self.prog_vocab_list[prog_buffer[0]])
		return cond

	def execute_block(self, karel_world, repeat_delta=0):
		"""Interpret the frame on top of ``self.block_stack``.

		Interpretation starts at ``self.block_pos_stack[-1]``.

		Args:
			karel_world: KarelExecutionState, mutated in place.
			repeat_delta: added to the count of a REPEAT that starts the
				frame.  KarelSimulator.execute passes -1, presumably because
				the first iteration has already run token by token.

		Returns:
			(karel_world, executed).  ``executed`` reflects the last construct
			interpreted.  It is False when that construct was an IFELSE whose
			condition was false and whose ELSE part is not in the frame.

		Every frame pushed here is popped before returning, including on
		step-limit early exits, so callers always see their own frame on top.
		"""
		executed = False
		if len(self.block_stack[-1]) == 0:
			executed = True
			return karel_world, executed
		if karel_world.stepsExceeded():
			return karel_world, True
		while self.block_pos_stack[-1] < len(self.block_stack[-1]):
			prog_token = self.block_stack[-1][self.block_pos_stack[-1]]
			if prog_token in self.action_hash:
				karel_world.executeAction(self.prog_vocab_list[prog_token])
				self.block_pos_stack[-1] += 1
				executed = True
				if karel_world.stepsExceeded():
					break
			elif prog_token == self.syntax_checker.syntax_vocab.repeat_token:
				# REPEAT R=<k> r( body r)
				times_token = self.prog_vocab_list[self.block_stack[-1][self.block_pos_stack[-1] + 1]]
				times = int(times_token[2:])
				if self.block_pos_stack[-1] == 0:
					times = times + repeat_delta
				repeat_ed = self.block_pos_stack[-1] + 3
				repeat_depth = 1
				repeat_prog_buffer = []
				while repeat_ed < len(self.block_stack[-1]):
					if self.block_stack[-1][repeat_ed] == self.syntax_checker.syntax_vocab.r_open_token:
						repeat_depth += 1
					elif self.block_stack[-1][repeat_ed] == self.syntax_checker.syntax_vocab.r_close_token:
						repeat_depth -= 1
					if repeat_depth == 0:
						break
					repeat_prog_buffer.append(self.block_stack[-1][repeat_ed])
					repeat_ed += 1
				self.block_stack.append(np.array(repeat_prog_buffer))
				self.block_pos_stack.append(0)
				for _ in range(times):
					self.block_pos_stack[-1] = 0
					# Each iteration check costs one tick.
					karel_world.ticks += 1
					if karel_world.stepsExceeded():
						break
					karel_world, executed = self.execute_block(karel_world)
					if karel_world.stepsExceeded():
						break
				self.block_stack.pop()
				self.block_pos_stack.pop()
				self.block_pos_stack[-1] = repeat_ed + 1
				executed = True
				if karel_world.stepsExceeded():
					break
			elif prog_token == self.syntax_checker.syntax_vocab.while_token:
				# WHILE c( cond c) w( body w)
				cond_st = self.block_pos_stack[-1] + 1
				cond_ed = cond_st + 1
				cond_depth = 1
				cond_prog_buffer = []
				while cond_ed < len(self.block_stack[-1]):
					if self.block_stack[-1][cond_ed] == self.syntax_checker.syntax_vocab.c_open_token:
						cond_depth += 1
					elif self.block_stack[-1][cond_ed] == self.syntax_checker.syntax_vocab.c_close_token:
						cond_depth -= 1
					if cond_depth == 0:
						break
					cond_prog_buffer.append(self.block_stack[-1][cond_ed])
					cond_ed += 1
				cond_value = self.eval_cond_seq(karel_world, np.array(cond_prog_buffer))
				if karel_world.stepsExceeded():
					executed = True
					break
				while_st = cond_ed  + 1
				while_ed = while_st + 1
				while_depth = 1
				while_prog_buffer = []
				while while_ed < len(self.block_stack[-1]):
					if self.block_stack[-1][while_ed] == self.syntax_checker.syntax_vocab.w_open_token:
						while_depth += 1
					elif self.block_stack[-1][while_ed] == self.syntax_checker.syntax_vocab.w_close_token:
						while_depth -= 1
					if while_depth == 0:
						break
					while_prog_buffer.append(self.block_stack[-1][while_ed])
					while_ed += 1
				self.block_stack.append(np.array(while_prog_buffer))
				self.block_pos_stack.append(0)

				while cond_value == True:
					self.block_pos_stack[-1] = 0
					karel_world, executed = self.execute_block(karel_world)
					if karel_world.stepsExceeded():
						break
					cond_value = self.eval_cond_seq(karel_world, np.array(cond_prog_buffer))
					if karel_world.stepsExceeded():
						break
					if cond_value == False:
						break
				self.block_stack.pop()
				self.block_pos_stack.pop()
				self.block_pos_stack[-1] = while_ed + 1
				executed = True
				if karel_world.stepsExceeded():
					break
			elif prog_token in self.syntax_checker.if_need_else:
				# IF/IFELSE c( cond c) i( body i) [ELSE e( body e)]
				cond_st = self.block_pos_stack[-1] + 1
				cond_ed = cond_st + 1
				cond_depth = 1
				cond_prog_buffer = []
				while cond_ed < len(self.block_stack[-1]):
					if self.block_stack[-1][cond_ed] == self.syntax_checker.syntax_vocab.c_open_token:
						cond_depth += 1
					elif self.block_stack[-1][cond_ed] == self.syntax_checker.syntax_vocab.c_close_token:
						cond_depth -= 1
					if cond_depth == 0:
						break
					cond_prog_buffer.append(self.block_stack[-1][cond_ed])
					cond_ed += 1
				cond_value = self.eval_cond_seq(karel_world, np.array(cond_prog_buffer))
				if karel_world.stepsExceeded():
					executed = True
					break
				if_st = cond_ed + 1
				if_ed = if_st + 1
				if_depth = 1
				if_prog_buffer = []
				while if_ed < len(self.block_stack[-1]):
					if self.block_stack[-1][if_ed] == self.syntax_checker.syntax_vocab.i_open_token:
						if_depth += 1
					elif self.block_stack[-1][if_ed] == self.syntax_checker.syntax_vocab.i_close_token:
						if_depth -= 1
					if if_depth == 0:
						break
					if_prog_buffer.append(self.block_stack[-1][if_ed])
					if_ed += 1
				else_st = None
				if if_ed + 1 < len(self.block_stack[-1]) and self.block_stack[-1][if_ed + 1] == self.syntax_checker.syntax_vocab.else_token:
					else_st = if_ed + 2
					else_ed = else_st + 1
					else_depth = 1
					else_prog_buffer = []
					while else_ed < len(self.block_stack[-1]):
						if self.block_stack[-1][else_ed] == self.syntax_checker.syntax_vocab.e_open_token:
							else_depth += 1
						elif self.block_stack[-1][else_ed] == self.syntax_checker.syntax_vocab.e_close_token:
							else_depth -= 1
						if else_depth == 0:
							break
						else_prog_buffer.append(self.block_stack[-1][else_ed])
						else_ed += 1
				if cond_value == True:
					self.block_stack.append(np.array(if_prog_buffer))
					self.block_pos_stack.append(0)
					karel_world, executed = self.execute_block(karel_world)
					executed = True
					# Pop the branch frame before any early exit so the caller's
					# frame is on top when we return.
					self.block_stack.pop()
					self.block_pos_stack.pop()
					if karel_world.stepsExceeded():
						break
				elif else_st is not None:
					self.block_stack.append(np.array(else_prog_buffer))
					self.block_pos_stack.append(0)
					karel_world, executed = self.execute_block(karel_world)
					executed = True
					# Same frame-balancing rule as the IF branch above.
					self.block_stack.pop()
					self.block_pos_stack.pop()
					if karel_world.stepsExceeded():
						break
				elif self.syntax_checker.if_need_else[prog_token]:
					# IFELSE with a false condition whose ELSE part is missing.
					executed = False
				else:
					executed = True
				if else_st is not None:
					self.block_pos_stack[-1] = else_ed + 1
				else:
					self.block_pos_stack[-1] = if_ed + 1
			else:
				print([self.prog_vocab_list[tok] for tok in self.block_stack[-1]])
				raise NotImplementedError
		return karel_world, executed

	def execute(self, init_karel_world, init_syntax_checker_state, prog_token):
		"""Apply one program token to copies of a world and a syntax state.

		Token handling:
		- Actions run immediately unless an enclosing WHILE/IF condition is
		  false.
		- Other opening tokens are buffered in ``karel_world.prog_buffer``.
		- A closing paren runs the buffered construct via ``execute_block``.
		- Structural tokens (DEF, run, m(, m), PAD) return
		  ``init_karel_world`` itself, not a copy.

		Returns:
			(karel_world, syntax_checker_state)
		"""
		syntax_checker_state = copy.deepcopy(init_syntax_checker_state)
		syntax_checker_state = self.syntax_checker.update_syntax_state(syntax_checker_state, prog_token)
		if prog_token == self.syntax_checker.syntax_vocab.def_token \
		or prog_token == self.syntax_checker.syntax_vocab.run_token \
		or prog_token == self.syntax_checker.syntax_vocab.m_open_token \
		or prog_token == self.syntax_checker.syntax_vocab.m_close_token \
		or prog_token == data_utils.PAD_ID:
			return init_karel_world, syntax_checker_state
		karel_world = copy.deepcopy(init_karel_world)
		if karel_world.stepsExceeded():
			return karel_world, syntax_checker_state

		# Charge the tick for the first REPEAT iteration check.  The remaining
		# iterations are run at `r)` with repeat_delta=-1.
		if prog_token == self.syntax_checker.syntax_vocab.r_open_token and \
			(len(karel_world.while_conditional_stack) == 0 or karel_world.while_conditional_stack[-1]) \
			and (len(karel_world.if_conditional_stack) == 0 or karel_world.if_conditional_stack[-1]) \
			and karel_world.pending_conditional == 0:
			karel_world.ticks += 1
		# Opening / non-executable tokens: buffer them and record construct starts.
		if (not (prog_token in self.action_hash)) and (not (prog_token in self.syntax_checker.close_paren_hash)):
			karel_world.update_prog_buffer(prog_token)
			if prog_token == self.prog_vocab['IFELSE']:
				# An IFELSE is deferred until its ELSE part closes.
				karel_world.pending_conditional += 1
			elif karel_world.pending_conditional == 0 and prog_token == self.syntax_checker.syntax_vocab.repeat_token:
				karel_world.repeat_stack.append(len(karel_world.prog_buffer) - 1)
			elif karel_world.pending_conditional == 0 and prog_token == self.syntax_checker.syntax_vocab.while_token:
				karel_world.while_stack.append(len(karel_world.prog_buffer) - 1)
			elif karel_world.pending_conditional == 0 and prog_token == self.prog_vocab['IF']:
				karel_world.if_stack.append(len(karel_world.prog_buffer) - 1)
			if karel_world.pending_conditional == 0 and syntax_checker_state.conditional_depth == 0:
				karel_world.executed_pos = len(karel_world.prog_buffer)
			return karel_world, syntax_checker_state

		if prog_token == self.syntax_checker.syntax_vocab.e_close_token:
			karel_world.pending_conditional -= 1

		# Inside a nested construct the token is also part of its buffered body.
		if len(syntax_checker_state.to_close_stack) > 1:
			karel_world.update_prog_buffer(prog_token)
			if karel_world.pending_conditional > 0:
				return karel_world, syntax_checker_state
		if prog_token in self.action_hash:
			if (len(karel_world.while_conditional_stack) == 0 or karel_world.while_conditional_stack[-1]) \
			and (len(karel_world.if_conditional_stack) == 0 or karel_world.if_conditional_stack[-1]):
				karel_world.executeAction(self.prog_vocab_list[prog_token])
			karel_world.executed_pos = len(karel_world.prog_buffer)
			return karel_world, syntax_checker_state
		if len(syntax_checker_state.to_close_stack) <= 1:
			karel_world.update_prog_buffer(prog_token)
		if prog_token == self.syntax_checker.syntax_vocab.c_close_token:
			if syntax_checker_state.conditional_depth == 0:
				# Find the matching c( and evaluate the whole condition.
				cond_st = len(karel_world.prog_buffer) - 2
				cond_depth = 1
				while cond_st > 0:
					if karel_world.prog_buffer[cond_st] == self.syntax_checker.syntax_vocab.c_open_token:
						cond_depth -= 1
					elif karel_world.prog_buffer[cond_st] == self.syntax_checker.syntax_vocab.c_close_token:
						cond_depth += 1
					if cond_depth == 0:
						break
					cond_st -= 1
				cond_value = self.eval_cond_seq(karel_world, np.array(karel_world.prog_buffer[cond_st + 1:-1]))
				# Conditions inside a skipped (false) WHILE/IF body are not charged.
				if (len(karel_world.while_conditional_stack) > 0 and karel_world.while_conditional_stack[-1] == False) \
					or (len(karel_world.if_conditional_stack) > 0 and karel_world.if_conditional_stack[-1] == False):
					karel_world.ticks -= 1
				if cond_st > 0 and karel_world.prog_buffer[cond_st - 1] == self.syntax_checker.syntax_vocab.while_token:
					if len(karel_world.while_conditional_stack) > 0 and karel_world.while_conditional_stack[-1] == False:
						karel_world.while_conditional_stack.append(False)
					else:
						karel_world.while_conditional_stack.append(cond_value)
				elif cond_st > 0 and karel_world.prog_buffer[cond_st - 1] == self.prog_vocab['IF']:
					if len(karel_world.if_conditional_stack) > 0 and karel_world.if_conditional_stack[-1] == False:
						karel_world.if_conditional_stack.append(False)
					else:
						karel_world.if_conditional_stack.append(cond_value)
					karel_world.executed_pos = len(karel_world.prog_buffer)
				# IFELSE conditions are re-evaluated when the construct runs, so
				# the tick(s) charged here are refunded.
				# NOTE(review): the amounts (1 vs 2) look tuned to match a
				# full-program interpreter -- TODO confirm.
				if cond_st > 0 and karel_world.prog_buffer[cond_st - 1] == self.prog_vocab['IFELSE']:
					if cond_value == True:
						karel_world.ticks -= 1
					else:
						karel_world.ticks -= 2
			return karel_world, syntax_checker_state
		if prog_token == self.syntax_checker.syntax_vocab.e_close_token:
			if karel_world.prog_buffer[karel_world.executed_pos] == self.syntax_checker.syntax_vocab.else_token:
				if karel_world.executed_pos == 0:
					karel_world.clear_prog_buffer()
				else:
					karel_world.executed_pos = len(karel_world.prog_buffer)
				return karel_world, syntax_checker_state
		if prog_token == self.syntax_checker.syntax_vocab.r_close_token:
			repeat_st = karel_world.repeat_stack.pop()
			if (len(karel_world.while_conditional_stack) == 0 or karel_world.while_conditional_stack[-1]) \
			and (len(karel_world.if_conditional_stack) == 0 or karel_world.if_conditional_stack[-1]):
				self.block_stack = []
				self.block_stack.append(np.array(karel_world.prog_buffer[repeat_st:]))
				self.block_pos_stack = []
				self.block_pos_stack.append(0)
				if not karel_world.crash:
					# First iteration already ran token by token.
					karel_world, executed = self.execute_block(karel_world, -1)
				else:
					executed = True
				self.block_stack = []
				self.block_pos_stack = []
			karel_world.executed_pos = len(karel_world.prog_buffer)
			if len(syntax_checker_state.to_close_stack) <= 1:
				karel_world.clear_prog_buffer()
			return karel_world, syntax_checker_state
		elif prog_token == self.syntax_checker.syntax_vocab.w_close_token:
			while_st = karel_world.while_stack.pop()
			while_cond = karel_world.while_conditional_stack.pop()
			if while_cond and (len(karel_world.if_conditional_stack) == 0 or karel_world.if_conditional_stack[-1]):
				self.block_stack = []
				self.block_stack.append(np.array(karel_world.prog_buffer[while_st:]))
				self.block_pos_stack = []
				self.block_pos_stack.append(0)
				if not karel_world.crash:
					karel_world, executed = self.execute_block(karel_world)
				else:
					executed = True
				self.block_stack = []
				self.block_pos_stack = []
			karel_world.executed_pos = len(karel_world.prog_buffer)
			if len(syntax_checker_state.to_close_stack) <= 1:
				karel_world.clear_prog_buffer()
			return karel_world, syntax_checker_state
		elif prog_token == self.syntax_checker.syntax_vocab.i_close_token:
			if karel_world.pending_conditional == 0:
				karel_world.if_stack.pop()
				karel_world.if_conditional_stack.pop()
				karel_world.executed_pos = len(karel_world.prog_buffer)
				if len(syntax_checker_state.to_close_stack) <= 1:
					karel_world.clear_prog_buffer()
				return karel_world, syntax_checker_state

		# Remaining closes (e.g. completed IFELSE): run everything buffered
		# since executed_pos.
		if (len(karel_world.while_conditional_stack) == 0 or karel_world.while_conditional_stack[-1]) \
		and (len(karel_world.if_conditional_stack) == 0 or karel_world.if_conditional_stack[-1]):
			self.block_stack = []
			self.block_stack.append(np.array(karel_world.prog_buffer[karel_world.executed_pos:]))
			self.block_pos_stack = []
			self.block_pos_stack.append(0)
			if not karel_world.crash:
				karel_world, executed = self.execute_block(karel_world)
			else:
				executed = True
			self.block_stack = []
			self.block_pos_stack = []
			if executed:
				karel_world.executed_pos = len(karel_world.prog_buffer)
		else:
			executed = True
		if executed:
			if len(syntax_checker_state.to_close_stack) <= 1:
				karel_world.clear_prog_buffer()
		return karel_world, syntax_checker_state


	def simulate(self, syntax_checker_states, init_predictions, init_karel_worlds, feed_previous, volatile_flag=False):
		"""Apply one token per batch element and return the resulting worlds.

		Args:
			syntax_checker_states: one parser state per batch element.  Each
				entry is replaced with its post-token state.
			init_predictions: when ``feed_previous`` is false, the tokens to
				apply.  Otherwise per-token scores, presumably probabilities.
				They are multiplied by the syntax mask, renormalised and used
				to pick a token.
			init_karel_worlds: ``[batch][io_count]`` KarelExecutionState objects.
			feed_previous: whether to choose tokens from ``init_predictions``.
			volatile_flag: forwarded to ``data_utils.np_to_tensor`` when tokens
				are chosen here.

		Returns:
			(predictions, input_grids, karel_worlds, syntax_checker_states).
			When tokens are chosen here, ``predictions`` is converted with
			``data_utils.np_to_tensor``.
		"""
		batch_size = len(init_predictions)
		# Decide once, so token choice and output conversion always agree,
		# even for truthy non-bool values such as numpy booleans.
		choose_tokens = not (feed_previous == False)
		if not choose_tokens:
			predictions = init_predictions
		else:
			probs = []
			for i in range(batch_size):
				current_prob = init_predictions[i] * syntax_checker_states[i].mask
				probs.append(current_prob / np.sum(current_prob))
			probs = np.array(probs)
			candidate_predictions = np.argsort(probs, axis=1)
			predictions = []
			for idx in range(batch_size):
				if self.eval_mode == 'random':
					# Sample a token, then only accept candidates no more likely than it.
					thres = np.random.choice(self.prog_vocab_size, 1, p = probs[idx])[0]
					thres = probs[idx][thres]
				else:
					thres = 1.0
				candidate_token = None
				# Most likely first; stop once below sample_bias.
				for prog_token in candidate_predictions[idx][::-1]:
					if probs[idx][prog_token] < self.sample_bias:
						break
					if probs[idx][prog_token] > thres:
						continue
					crash = False
					for io_idx in range(self.io_count):
						cur_karel_world, cur_syntax_checker_state = self.execute(init_karel_worlds[idx][io_idx], syntax_checker_states[idx], prog_token)
						if cur_karel_world.crash:
							crash = True
							break
					if not crash:
						candidate_token = prog_token
						break

				if candidate_token is None:
					# Fallback: most likely non-crashing token, ignoring threshold and bias.
					for prog_token in candidate_predictions[idx][::-1]:
						crash = False
						for io_idx in range(self.io_count):
							cur_karel_world, cur_syntax_checker_state = self.execute(init_karel_worlds[idx][io_idx], syntax_checker_states[idx], prog_token)
							if cur_karel_world.crash:
								crash = True
								break
						if not crash:
							candidate_token = prog_token
							break
				predictions.append(candidate_token)
		predictions = np.array(predictions)
		karel_worlds = []
		input_grids = []
		for idx in range(batch_size):
			cur_karel_worlds = []
			cur_input_grids = []
			for io_idx in range(self.io_count):
				cur_karel_world, cur_syntax_checker_state = self.execute(init_karel_worlds[idx][io_idx], syntax_checker_states[idx], predictions[idx])
				cur_karel_worlds.append(cur_karel_world)
				cur_input_grids.append(cur_karel_world.grid)
			karel_worlds.append(cur_karel_worlds)
			input_grids.append(cur_input_grids)
			syntax_checker_states[idx] = cur_syntax_checker_state

		input_grids = np.array(input_grids)
		if choose_tokens:
			predictions = data_utils.np_to_tensor(predictions, 'int', cuda_flag=self.cuda_flag, volatile_flag=volatile_flag)
			input_grids = data_utils.np_to_tensor(input_grids, 'float', cuda_flag=self.cuda_flag, volatile_flag=volatile_flag)
		else:
			input_grids = data_utils.np_to_tensor(input_grids, 'float', cuda_flag=self.cuda_flag)
		return predictions, input_grids, karel_worlds, syntax_checker_states

