import numpy as np
import argparse
import sys
import os
import torch
import re
import json
import pickle
from torch.nn.utils import clip_grad_norm
from ..data_utils import data_utils
from ..simulators import karel_simulator

# Matches the checkpoint file names written by Supervisor.save_model
# ("ckpt-<digits>"); group 1 captures the step digits. Written as a raw string
# so that \d reaches the regex engine instead of being parsed as an (invalid)
# Python string escape.
CKPT_PATTERN = re.compile(r'^ckpt-(\d+)$')

class Supervisor(object):
	def __init__(self, model, args, prog_vocab, prog_vocab_list):
		self.model = model
		self.eval_mode = args.eval_mode
		self.keep_last_n = args.keep_last_n
		self.dropout_rate = args.dropout_rate
		self.global_step = 0
		self.batch_size = args.batch_size
		self.model_dir = args.model_dir
		self.prog_format = args.prog_format
		self.prog_vocab = prog_vocab
		self.prog_vocab_list = prog_vocab_list
		self.prediction_output = args.prediction_output
		self.compatibility_mode = args.compatibility_mode

	def load_pretrained(self, load_model):
		print("Read model parameters from %s." % load_model)
		checkpoint = torch.load(load_model)
		if self.compatibility_mode:
			checkpoint = {k : v for k, v in checkpoint.items() if k.split(".")[0] != "value_func_model"}
		new_state_dict = dict()
		for key, value in checkpoint.items():
			name = key[7:]
			new_state_dict[name] = value
		self.model.load_state_dict(new_state_dict)

	def save_model(self):
		if not os.path.exists(self.model_dir):
			os.makedirs(self.model_dir)
		global_step_padded = format(self.global_step, '08d')
		ckpt_name = 'ckpt-' + global_step_padded
		path = os.path.join(self.model_dir, ckpt_name)
		ckpt = self.model.state_dict()
		torch.save(ckpt, path)

		if self.keep_last_n is not None:
			ckpts = []
			for file_name in os.listdir(self.model_dir):
				matched_name = CKPT_PATTERN.match(file_name)
				if matched_name is None or matched_name == ckpt_name:
					continue
				step = int(matched_name.group(1))
				ckpts.append((step, file_name))
			if len(ckpts) > self.keep_last_n:
				ckpts.sort()
				os.unlink(os.path.join(self.model_dir, ckpts[0][1]))


	def train(self, batch_data, feed_previous=False):
		self.model.dropout_rate = self.dropout_rate
		predictions_logit, predictions_prog = self.model(batch_data, feed_previous=feed_previous)
		
		if self.global_step == 0:
			self.model.module.optimizer.zero_grad()
		total_loss = self.model.module.compute_loss(batch_data, predictions_logit)
		total_loss.backward()
		
		self.model.module.train_step()
		self.model.module.optimizer.zero_grad()

		self.global_step += 1
		return total_loss.item()

	def compute_output_prog_acc(self, data, predictions, satisfied, simulator, random_data=None):
		exact_acc = 0
		generalization_acc = 0
		semantic_acc = 0
		functional_acc = 0
		semantic_functional_acc = 0
		n_data = len(data)
		res = {}
		correct_exact_idxes = []
		correct_generalization_idxes = []
		correct_semantic_idxes = []
		pred_progs = []
		for idx in range(n_data):
			data_line = data[idx]
			data_item = json.loads(data_line)
			prog = data_utils.read_prog(data_item, self.prog_format, self.prog_vocab)
			pred = predictions[idx]
			train_ios, val_io = data_utils.read_ios(data_item)
			cur_exact_acc = 1
			cur_pred_prog = []
			for time_step in range(len(pred)):
				if pred[time_step] == data_utils.PAD_ID:
					break
				cur_pred_prog.append(self.prog_vocab_list[pred[time_step]])
			for time_step in range(len(pred)):
				if time_step < len(prog) and pred[time_step] != prog[time_step]:
					cur_exact_acc = 0
					break
			cur_generalization_acc = 1
			cur_semantic_acc = 1
			for i, (val_i, val_o) in enumerate(train_ios + val_io):
				syntax_checker_state = simulator.syntax_checker.get_init_syntax_state()
				karel_world = karel_simulator.KarelExecutionState(val_i)
				for time_step in range(len(pred)):
					if pred[time_step] == data_utils.PAD_ID:
						break
					karel_world, syntax_checker_state = simulator.execute(karel_world, syntax_checker_state, pred[time_step])
					if karel_world.crash:
						cur_generalization_acc = 0
						if i < 5:
							cur_semantic_acc = 0
						break
				if cur_generalization_acc == 0:
					break
				target_karel_world = karel_simulator.KarelExecutionState(val_o)
				if karel_world != target_karel_world:
					cur_generalization_acc = 0
					if i < 5:
						cur_semantic_acc = 0
					break
			if random_data is not None:
				random_data_item = random_data['sources'][idx]
				val_ios = data_utils.read_random_ios(random_data_item)
				cur_functional_acc = 1
				for val_i, val_o in val_ios:
					syntax_checker_state = simulator.syntax_checker.get_init_syntax_state()
					karel_world = karel_simulator.KarelExecutionState(val_i)
					for time_step in range(len(pred)):
						if pred[time_step] == data_utils.PAD_ID:
							break
						karel_world, syntax_checker_state = simulator.execute(karel_world, syntax_checker_state, pred[time_step])
						if karel_world.crash:
							cur_functional_acc = 0
							break
					if cur_functional_acc == 0:
						break
					target_karel_world = karel_simulator.KarelExecutionState(val_o)
					if karel_world != target_karel_world:
						cur_functional_acc = 0
						break
				functional_acc += cur_functional_acc
				if cur_semantic_acc == 1 and cur_functional_acc == 1:
					semantic_functional_acc += 1
			if cur_semantic_acc == 1:
				correct_semantic_idxes.append(idx)
			if cur_generalization_acc == 1 or cur_exact_acc == 1:
				correct_generalization_idxes.append(idx)
			if cur_exact_acc == 1:
				correct_exact_idxes.append(idx)
			exact_acc += cur_exact_acc
			generalization_acc += cur_generalization_acc
			semantic_acc += cur_semantic_acc
			pred_progs.append(cur_pred_prog)
		exact_acc = exact_acc * 1.0 / n_data
		generalization_acc = generalization_acc * 1.0 / n_data
		semantic_acc = semantic_acc * 1.0 / n_data
		semantic_functional_acc = semantic_functional_acc * 1.0 / n_data
		res['gen'] = correct_generalization_idxes
		res['exact'] = correct_exact_idxes
		res['semantic'] = correct_semantic_idxes
		if random_data is not None:
			functional_acc = functional_acc * 1.0 / n_data
		if satisfied is not None:
			res['satisfied'] = satisfied
		res['pred'] = pred_progs
		json.dump(res, open(self.prediction_output, 'w'))
		if random_data is not None:
			return exact_acc, generalization_acc, semantic_acc, functional_acc, semantic_functional_acc
		return exact_acc, generalization_acc, semantic_acc

	def eval(self, data, eval_max_size=None, feed_previous=True, random_data=None):
		"""Evaluate the model on `data` and return `(loss,) + accuracies`.

		At most `eval_max_size` leading examples are used (all of them when it
		is None). Examples are processed in batches of `self.batch_size` with
		gradients disabled and the model in eval mode. The accuracies are
		whatever `compute_output_prog_acc` returns for the top-one
		predictions, so the tuple has 4 entries, or 6 when `random_data` is
		given.

		When `self.eval_mode == 'search'` the model is expected to return
		`(candidate_programs, satisfied)` per batch; the first candidate of
		each example is scored, the loss slot is -1, and the function returns
		without switching the model back to train mode. Otherwise the model
		returns `(logits, programs)`, the loss slot is the accumulated
		`loss * len(logits)` divided by the number of evaluated examples, and
		the model is put back into train mode before returning.
		"""
		# compute_loss and the simulator live on the underlying model when
		# self.model is a parallel wrapper.
		if isinstance(self.model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
			core_model = self.model.module
		else:
			core_model = self.model

		with torch.no_grad():
			self.model.eval()
			total_loss = 0
			top_one_predictions = []
			satisfied = []

			data_size = len(data)
			if eval_max_size is not None:
				data_size = min(data_size, eval_max_size)
			# Slice once instead of once per batch.
			eval_data = data[:data_size]
			for start_idx in range(0, data_size, self.batch_size):
				print('executing: ', start_idx)
				batch_data = data_utils.get_batch(eval_data, self.batch_size, self.prog_format, self.prog_vocab, start_idx=start_idx)
				if self.eval_mode == 'search':
					cur_predictions_prog, cur_satisfied = self.model(batch_data, feed_previous=feed_previous)
					# Keep only the first candidate program of every example.
					top_one_predictions.extend(prog_pred_list[0] for prog_pred_list in cur_predictions_prog)
					satisfied.extend(cur_satisfied)
					continue
				cur_predictions_logit, cur_predictions_prog = self.model(batch_data, feed_previous=feed_previous)
				cur_loss = core_model.compute_loss(batch_data, cur_predictions_logit).item()
				# cur_predictions_prog is indexed [time_step][batch_idx];
				# regroup it into one token sequence per example.
				for batch_idx in range(len(batch_data)):
					top_one_predictions.append([step_tokens[batch_idx] for step_tokens in cur_predictions_prog])
				# NOTE(review): the weight is len(cur_predictions_logit); confirm
				# that this is the batch size and not the number of time steps.
				total_loss += cur_loss * len(cur_predictions_logit)
			top_one_predictions = np.array(top_one_predictions)
			if len(satisfied) == 0:
				satisfied = None
			acc = self.compute_output_prog_acc(eval_data, top_one_predictions, satisfied, core_model.simulator, random_data)
			if self.eval_mode == 'search':
				return (-1,) + acc
			# max(..., 1) keeps an empty evaluation from dividing by zero.
			total_loss /= max(len(top_one_predictions), 1)
			self.model.train()
			return (total_loss,) + acc
