import argparse
import torch
import torch.nn as nn
from config import Config
import os, sys, math
from MetaT5Datasets import get_single_h5py_nlp_data, _data_wrapper, TASK_NAME_LIST
import t5_model
import numpy as np
import random
import utils
from tqdm import tqdm
import deepspeed
import re
from sklearn.metrics import f1_score as sen_f1_score
from sklearn.metrics import accuracy_score
import json
from evaluate_utils import *

def _average_all(tensor):
	"""Return the element-wise mean of ``tensor`` across all distributed ranks.

	The input is never modified: the reduction runs on a detached copy. When
	no default process group is initialised (plain single-process run) the
	copy is treated as the only participant, i.e. a world of size 1.

	Args:
		tensor: Tensor holding this rank's value.

	Returns:
		A new tensor, detached from the autograd graph, equal to the sum over
		all ranks divided by the world size.
	"""
	# We copy because `all_reduce` modifies its argument in-place.
	averaged = tensor.detach().clone()
	if torch.distributed.is_available() and torch.distributed.is_initialized():
		# We use `all_reduce` because it is better supported than `reduce`.
		torch.distributed.all_reduce(averaged, torch.distributed.ReduceOp.SUM)
		world_size = torch.distributed.get_world_size()
	else:
		# No process group: this process is the whole "world".
		world_size = 1
	return averaged / world_size

# Captures everything between the first `<extra_id_0>` and the last
# `<extra_id_1>` on the same line (greedy, `.` does not cross newlines).
compiler = re.compile(r'<extra_id_0>(.*)<extra_id_1>')
def _extract(text):
	"""Pull the answer span out of a decoded generation.

	If ``text`` contains an ``<extra_id_0> ... <extra_id_1>`` pair, only the
	captured span is kept; otherwise the whole text is used. The chosen
	string is stripped first and then any remaining ``<extra_id_0>`` /
	``<extra_id_1>`` markers are deleted (so whitespace that was adjacent to
	a removed marker is kept).

	Args:
		text: Decoded model output.

	Returns:
		The cleaned answer string.
	"""
	found = compiler.search(text)
	candidate = text if found is None else found.group(1)
	cleaned = candidate.strip()
	for marker in ("<extra_id_0>", "<extra_id_1>"):
		cleaned = cleaned.replace(marker, "")
	return cleaned

def get_label_list(label_path):
	"""Read one label per line from ``label_path``.

	Each line is stripped of surrounding whitespace and every underscore is
	replaced by a single space. Blank lines yield empty-string labels.

	Args:
		label_path: Path of the text file to read.

	Returns:
		List of label strings in file order.
	"""
	with open(label_path) as out:
		return [raw_line.strip().replace('_', ' ') for raw_line in out]

def is_match(sub_sentence, generated_entity):
	"""Tell whether ``generated_entity`` equals, starts or ends ``sub_sentence``.

	Args:
		sub_sentence: Reference text.
		generated_entity: Candidate text produced by the model.

	Returns:
		True when the two are equal, or when ``sub_sentence`` begins or ends
		with ``generated_entity``; False otherwise.
	"""
	# The equality test comes first so the prefix/suffix checks are only
	# attempted when the two values differ.
	return bool(
		sub_sentence == generated_entity
		or sub_sentence.startswith(generated_entity)
		or sub_sentence.endswith(generated_entity)
	)

def clean_up_t5_tokenizer(tokenizer, sen_np):
    """Turn a sequence of token ids back into plain text.

    The ids are mapped to token strings with
    ``tokenizer.convert_ids_to_tokens``; the character ``chr(9601)``
    (U+2581, presumably the tokenizer's word-boundary marker) becomes a
    space, the tokens are concatenated, and ``</s>`` / ``<pad>`` are removed
    before the result is stripped.

    Args:
        tokenizer: Object exposing ``convert_ids_to_tokens``.
        sen_np: Sequence of token ids accepted by that method.

    Returns:
        The decoded, cleaned string.
    """
    boundary = chr(9601)
    pieces = [piece.replace(boundary, ' ') for piece in tokenizer.convert_ids_to_tokens(sen_np)]
    decoded = ''.join(pieces)
    for special in ('</s>', '<pad>'):
        decoded = decoded.replace(special, '')
    return decoded.strip()

def get_labels(label_list, gen_label):
	"""Map a generated string onto one of the known labels.

	An exact match wins. Otherwise the first label (in ``label_list`` order)
	that occurs inside ``gen_label`` is returned, except that a label which is
	itself a proper substring of another matching label is skipped. Without
	that rule "not entailment ." would be resolved to "entailment" whenever
	"entailment" is listed before "not entailment".

	Args:
		label_list: Sequence of label strings.
		gen_label: Text produced by the model.

	Returns:
		The matching label, or None when no label occurs in ``gen_label``.
	"""
	if gen_label in label_list:
		return gen_label

	matches = [label for label in label_list if label in gen_label]
	for candidate in matches:
		# A shorter label contained in a longer matching label is only a
		# fragment of that longer label, not an independent hit.
		shadowed = any(candidate != other and candidate in other for other in matches)
		if not shadowed:
			return candidate
	return None

def gen_for_nlu(_C, eval_data, model, device, tokenizer, output_path=None, is_root=True):
	"""Generate a prediction for every example of ``eval_data`` and score them.

	Each rank decodes its own batches with ``model.generate``. The per-example
	results are exchanged with ``torch.distributed.all_gather_object`` so every
	rank holds the same list and computes the same metrics. Examples sharing a
	``data_index`` are merged by majority vote over their predictions.

	Args:
		_C: Config object; ``label_path`` and ``running_task`` are read.
		eval_data: Iterable of batch dicts. The entries ``gt_x``, ``gt_y`` and
			``data_index`` are left as they are; every other value is moved to
			``device``.
		model: Model exposing ``eval()`` and ``generate(...)``.
		device: Target device for the batch tensors.
		tokenizer: Tokenizer exposing ``eos_token_id`` and
			``convert_ids_to_tokens``.
		output_path: Optional file that receives one tab-separated line per
			aggregated example (written by the root rank only).
		is_root: Whether this process shows the progress bar, prints the
			summary and writes ``output_path``.

	Returns:
		The dict returned by ``compute_metrics``.

	Raises:
		NotImplementedError: When ``output_path`` is given for the ``record``
			task and there is at least one example to write.
		ValueError: For classification tasks, when a ``gt_y`` is not contained
			in the label list.
	"""
	# Local import keeps this block self-contained.
	from collections import Counter

	model.eval()
	label_list = get_label_list(_C.label_path)
	passthrough_keys = ('gt_x', 'gt_y', 'data_index')

	output_lines = []
	with torch.no_grad():
		# Only the root rank shows a progress bar.
		pbar = tqdm(eval_data) if is_root else eval_data

		for batch in pbar:
			for n in batch:
				if n in passthrough_keys:
					continue
				batch[n] = batch[n].to(device)

			outputs = model.generate(
				input_ids=batch['encoder_input_ids'],
				attention_mask=batch['encoder_mask'],
				task_ids=batch['task_ids'],
				task_type_ids=batch['task_type_ids'],
				max_length=15,
				min_length=1,
				eos_token_id=tokenizer.eos_token_id,
				num_return_sequences=1,
				num_beams=1,
				early_stopping=True,
			)

			batch_lines = []
			for i in range(batch['decoder_input_ids'].size(0)):
				gt_x, gt_y, data_index = batch['gt_x'][i], batch['gt_y'][i], batch['data_index'][i]
				output_seq = _extract(clean_up_t5_tokenizer(tokenizer, outputs[i]))
				batch_lines.append((gt_x, gt_y, output_seq, data_index))

			# Share this rank's results with every other rank (requires an
			# initialised default process group).
			all_batch_lines = [None for _ in range(torch.distributed.get_world_size())]
			torch.distributed.all_gather_object(all_batch_lines, batch_lines)
			for rank_batch_line in all_batch_lines:
				output_lines += rank_batch_line

	# 'wsc' and 'record' keep the generated text itself as the prediction;
	# every other task maps the text onto an index into `label_list`.
	is_generative = _C.running_task in ('wsc', 'record')

	gen_labels, gt_labels = {}, {}
	total_instances, correct_instances = 0, 0
	for (gt_x, gt_y, output_seq, data_index) in output_lines:
		total_instances += 1
		votes = gen_labels.setdefault(data_index, [])
		if is_generative:
			if output_seq in gt_x:
				correct_instances += 1
			votes.append((gt_x, output_seq))
			gt_labels[data_index] = gt_y
		else:
			gen_label = get_labels(label_list, output_seq)
			if gen_label is not None:
				correct_instances += 1
				votes.append((gt_x, label_list.index(gen_label)))
			else:
				# Unparseable generations fall back to label index 0.
				votes.append((gt_x, 0))
			gt_labels[data_index] = label_list.index(gt_y)

	final_gt, final_gen, final_x, idxs = [], [], [], []
	for data_index, gt in gt_labels.items():
		votes = gen_labels[data_index]
		idxs.append(data_index)
		final_gt.append(gt)
		# Counter breaks ties by first occurrence, so every rank picks the
		# same winner; max(set(...)) depended on per-process string hashing.
		final_gen.append(Counter(pred for _, pred in votes).most_common(1)[0][0])
		final_x.append(votes[0][0])
	metric_dict = compute_metrics(_C.running_task, final_gt, final_gen, data_index=idxs)

	if is_root:
		correct_ratio = 100 * correct_instances / total_instances if total_instances else 0.0
		print("Correct Ratio %.2f" % correct_ratio)
		performance_str = f"Instances {len(output_lines)} ==> "
		for k, v in metric_dict.items():
			performance_str += '{} {:.2f}, '.format(k, v)
		print(performance_str)

		if output_path is not None:
			with open(output_path, 'w') as out:
				for input_x, gen_y, gt_y, idx in zip(final_x, final_gen, final_gt, idxs):
					if _C.running_task == 'wsc':
						out.write('%s\t%s\t%s\t%s\n' % (str(input_x), str(gen_y), str(gt_y), str(idx)))
					elif _C.running_task == 'record':
						raise NotImplementedError()
					else:
						out.write('%s\t%s\t%s\t%s\n' % (str(input_x), str(label_list[gen_y]), str(label_list[gt_y]), str(idx)))
	return metric_dict


def gen_for_nlu_record_eval(_C, eval_data, model, device, output_path=None, is_root=True):
	"""Score every example of ``eval_data`` by its loss and compute metrics.

	Each rank runs a forward pass on its own batches and records one loss
	value per example (``outputs.loss`` is indexed per example, so the model
	is expected to return an unreduced loss). The results are exchanged with
	``torch.distributed.all_gather_object`` and handed to
	``compute_metrics_record_evaluate``.

	Args:
		_C: Config object (kept for interface compatibility; not read here).
		eval_data: Iterable of batch dicts. The entries ``gt_x``, ``gt_y`` and
			``data_index`` are left as they are; every other value is moved to
			``device``.
		model: Model that is called with the batch and returns an object with
			a ``loss`` attribute.
		device: Target device for the batch tensors.
		output_path: Forwarded to ``compute_metrics_record_evaluate``.
		is_root: Whether this process shows the progress bar and prints the
			summary.

	Returns:
		The dict returned by ``compute_metrics_record_evaluate``.
	"""
	model.eval()
	passthrough_keys = ('gt_x', 'gt_y', 'data_index')

	output_lines = []
	with torch.no_grad():
		# Only the root rank shows a progress bar.
		pbar = tqdm(eval_data) if is_root else eval_data

		for batch in pbar:
			for n in batch:
				if n in passthrough_keys:
					continue
				batch[n] = batch[n].to(device)

			# Use the `model` argument; the previous code reached for a
			# module-level `dist_model` that only exists when this file is
			# executed as a script.
			outputs = model(
				input_ids=batch['encoder_input_ids'],
				attention_mask=batch['encoder_mask'],
				labels=batch['decoder_input_ids'],
				task_ids=batch['task_ids'],
				task_type_ids=batch['task_type_ids']
			)
			loss = outputs.loss

			batch_lines = []
			for i in range(batch['decoder_input_ids'].size(0)):
				gt_x, gt_y, data_index = batch['gt_x'][i], batch['gt_y'][i], batch['data_index'][i]
				# Convert to a Python float before gathering so that no
				# device tensor has to be pickled by all_gather_object.
				batch_lines.append((gt_x, gt_y, loss[i].item(), data_index))

			all_batch_lines = [None for _ in range(torch.distributed.get_world_size())]
			torch.distributed.all_gather_object(all_batch_lines, batch_lines)
			for rank_batch_line in all_batch_lines:
				output_lines += rank_batch_line

	gen_labels, gt_labels = {}, {}
	total_instances, correct_instances = 0, 0

	for (gt_x, gt_y, loss_i, data_index) in output_lines:
		total_instances += 1
		gen_labels.setdefault(data_index, []).append((gt_x, loss_i))
		gt_labels[data_index] = gt_y

	final_gt, final_loss, final_x, idxs = [], [], [], []
	for data_index, gt in gt_labels.items():
		# Only the first record seen for a data_index contributes its input
		# and loss; the ground truth is the one seen last.
		first_x, first_loss = gen_labels[data_index][0]
		idxs.append(data_index)
		final_gt.append(gt)
		final_loss.append(first_loss)
		final_x.append(first_x)

	metric_dict = compute_metrics_record_evaluate(candidate=final_gt, loss=final_loss, data_index=idxs,
												  output_path=output_path)

	if is_root:
		# NOTE(review): nothing increments `correct_instances` in this
		# function, so this line always reports 0.00; kept for log parity.
		correct_ratio = 100 * correct_instances / total_instances if total_instances else 0.0
		print("Correct Ratio %.2f" % correct_ratio)
		performance_str = f"Instances {len(output_lines)} ==> "
		for k, v in metric_dict.items():
			performance_str += '{} {:.2f}, '.format(k, v)
		print(performance_str)

	return metric_dict


# Command-line interface. The parser is built at import time; DeepSpeed's own
# options are appended by `deepspeed.add_config_arguments` at the end.
# The text is passed as `description=`: the first positional parameter of
# ArgumentParser is `prog`, which would turn the sentence into the program
# name shown in usage/error messages.
parser = argparse.ArgumentParser(description="Train a MT5 for Machine Translation")
parser.add_argument(
    "--config", required=True, help="Path to a config file with all configuration parameters."
)
parser.add_argument(
    "--config-override",
    default=[],
    nargs="*",
    help="A sequence of key-value pairs specifying certain config arguments (with dict-like "
    "nesting) using a dot operator. The actual config will be updated and recorded in "
    "the serialization directory.",
)
parser.add_argument(
    "--serialization-dir",
    default=None,
    help="Path to a (non-existent) directory for serializing checkpoints and tensorboard logs.",
)
parser.add_argument(
    "--start-from-checkpoint",
    default=None,
    help="Path to load checkpoint and continue training [only supported for module_training].",
)
parser.add_argument(
    "--output-path",
    default=None,
    help="Path to save output captions",
)
parser.add_argument('--local_rank', type=int, default=-1,
                    help='local rank passed from distributed launcher')
# Run mode: at most one of --train / --validation / --test may be given.
group = parser.add_mutually_exclusive_group()
group.add_argument('--train', action='store_true')
group.add_argument('--validation', action='store_true')
group.add_argument('--test', action='store_true')
parser = deepspeed.add_config_arguments(parser)

if __name__ == "__main__":
	# Script entry point: build model and data loaders from the config, then
	# either evaluate once (--validation / --test) or train with periodic
	# evaluation (--train).
	_A = parser.parse_args()
	_C = Config(_A.config, _A.config_override)

	# Seed every RNG in use and ask cuDNN for deterministic kernels.
	np.random.seed(_C.random_seed)
	random.seed(_C.random_seed)
	torch.manual_seed(_C.random_seed)
	torch.cuda.manual_seed_all(_C.random_seed)
	torch.backends.cudnn.benchmark = False
	torch.backends.cudnn.deterministic = True

	os.environ["TOKENIZERS_PARALLELISM"] = "false"
	os.environ["NCCL_DEBUG"] = "WARN"
	local_rank = _A.local_rank
	torch.cuda.set_device(local_rank)
	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

	if _A.deepspeed:
		deepspeed.init_distributed()

	is_root = (not _A.deepspeed) or torch.distributed.get_rank() == 0

	# Decide how many task embeddings the model is created with.
	if _C.load_from_pretrained:
		if _C.enable_pretrain_task_embeddings:
			_C.task_embed_count = len(TASK_NAME_LIST)
	else:
		if _C.enable_new_task_embeddings:
			_C.task_embed_count = _C.prefix_set_number if _C.prefix_set_number > 0 else 1

	if _C.enable_full_finetune:
		tokenizer, model = t5_model.get_full_finetune_t5_model(_C)
	elif _C.enable_full_pretrain:
		tokenizer, model = t5_model.get_full_pretrain_t5_model(_C)
	else:
		tokenizer, model = t5_model.get_t5_model(_C)

	# Under DeepSpeed the configured batch sizes are global, so each rank
	# gets its share per micro-step.
	if _A.deepspeed:
		val_batch_size = _C.val_batch_size // (torch.distributed.get_world_size() * _C.gradient_accumulation_steps)
	else:
		val_batch_size = _C.val_batch_size
	dev_loader = get_single_h5py_nlp_data(_C, _C.dev_path, _C.train_path, "validation", val_batch_size, tokenizer, _C.max_length, shuffle=True, distributed=_A.deepspeed, is_root=is_root, is_train=False, random_in_context=_C.random_in_context, duplication=_C.duplication)
	test_loader = get_single_h5py_nlp_data(_C, _C.test_path, _C.train_path, "validation", val_batch_size, tokenizer, _C.max_length, shuffle=True, distributed=_A.deepspeed, is_root=is_root, is_train=False, random_in_context=_C.random_in_context, duplication=_C.duplication)

	if _A.deepspeed:
		train_batch_size = _C.batch_size // (torch.distributed.get_world_size() * _C.gradient_accumulation_steps)
	else:
		train_batch_size = _C.batch_size
	train_loader = get_single_h5py_nlp_data(_C, _C.train_path, _C.train_path, "train", train_batch_size, tokenizer, _C.max_length, shuffle=True, distributed=_A.deepspeed, is_root=is_root, is_train=True, random_in_context=_C.random_in_context)

	ds_config = {
		"train_batch_size": _C.batch_size,
		"gradient_accumulation_steps": _C.gradient_accumulation_steps,
		"steps_per_print": 100,
		"fp16": {
		  "enabled": False
		},
	}

	if _C.enable_adam_opt:
		optimizer = utils.build_optimizer(_C, model)
	elif _C.enable_full_finetune:
		optimizer = utils.build_adam_optimizer(_C, model)
	else:
		optimizer = utils.build_t5_optimizer(_C, model)

	dist_model, _, _, _ = deepspeed.initialize(args=_A, model=model, model_parameters=[p for p in model.parameters() if p.requires_grad], config=ds_config, optimizer=optimizer)
	if _A.start_from_checkpoint is not None:
		dist_model.load_checkpoint(_A.start_from_checkpoint, load_module_strict=False, load_module_only=True)
	if _C.enable_new_task_embeddings and _C.load_from_pretrained:
		dist_model.module.update_task_embedding(_C.prefix_set_number if _C.prefix_set_number > 0 else 1)

	if is_root:
		total_parameter_count = 0
		trainable_parameter_count = 0
		for p in model.parameters():
			total_parameter_count += p.numel()
			if p.requires_grad:
				trainable_parameter_count += p.numel()
		print('Total Parameter Count %d' % total_parameter_count)
		print('Trainable Parameter Count %d' % trainable_parameter_count)

		print(_C)
		for arg in vars(_A):
			print("{:<20}: {}".format(arg, getattr(_A, arg)))

	if _A.validation or _A.test:
		if _C.enable_nlu:
			_scores = gen_for_nlu(_C, dev_loader if _A.validation else test_loader, model, device, tokenizer, output_path=_A.output_path, is_root=is_root)
	if _A.train:
		train_iter = iter(train_loader)

		if _C.num_training_steps == 0:
			length_train_data = len(train_iter) // (_C.batch_size // torch.distributed.get_world_size())
			_C.num_training_steps = int(length_train_data * _C.max_epoch / _C.gradient_accumulation_steps)
		# An "epoch" here is one evaluation interval of
		# `checkpoint_every_step` optimizer steps, not a pass over the data.
		epoch_num = math.ceil(_C.num_training_steps / _C.checkpoint_every_step)

		# NOTE(review): --serialization-dir defaults to None, which makes the
		# next line fail; it is effectively required together with --train.
		os.makedirs(_A.serialization_dir, exist_ok=True)
		_C.dump(os.path.join(_A.serialization_dir, "config.yml"))

		eval_every = _C.checkpoint_every_step * _C.gradient_accumulation_steps
		total_step = 0
		from evaluate_utils import task_metrics
		# Despite its name this holds the best (highest) dev scores so far.
		lowest_loss = [-1e10] * len(task_metrics[_C.running_task]) if _C.running_task != 'record' else [-1e10, -1e10]
		best_test_performance = None

		for epoch in range(epoch_num):
			# Micro-steps to run in this interval; the last one may be shorter.
			total_micro_steps = _C.num_training_steps * _C.gradient_accumulation_steps
			run_step = eval_every if total_step + eval_every < total_micro_steps else total_micro_steps - total_step
			dist_model.train()

			if is_root:
				print('EPOCH %d / %d' % (epoch + 1, epoch_num))
				pbar = tqdm(total=math.ceil(run_step / _C.gradient_accumulation_steps), file=sys.stdout)

			for step in range(run_step):
				try:
					batch = next(train_iter)
				except StopIteration:
					# Loader exhausted: start another pass. Only StopIteration
					# is handled so real data-loading errors (and Ctrl-C) are
					# no longer swallowed and retried.
					train_iter = iter(train_loader)
					batch = next(train_iter)

				for n in batch:
					if n in ['gt_x', 'gt_y', 'data_index']: continue
					batch[n] = batch[n].to(dist_model.local_rank)
				total_step += 1
				outputs = dist_model(
					input_ids=batch['encoder_input_ids'],
					attention_mask=batch['encoder_mask'],
					labels=batch['decoder_input_ids'],
					task_ids=batch['task_ids'],
					task_type_ids=batch['task_type_ids']
				)
				loss = outputs.loss
				# The loss is averaged here before the backward pass.
				loss = torch.mean(loss)
				dist_model.backward(loss)
				dist_model.step()

				ave_loss = _average_all(loss).item()

				if is_root:
					pbar.set_description("loss %.2f" % (ave_loss * _C.gradient_accumulation_steps))
					pbar.update(1)
					pbar.refresh()

			if is_root:
				pbar.close()

			# NOTE(review): `_scores` / `_keys` are only assigned when
			# `enable_nlu` is set; the comparison below needs them.
			if _C.enable_nlu:
				if _C.running_task != 'record':
					_scores = gen_for_nlu(_C, dev_loader, model, device, tokenizer, output_path=_A.output_path, is_root=is_root)
				else:
					_scores = gen_for_nlu_record_eval(_C, dev_loader, model, device, is_root=is_root)
				_keys = [k for k, v in _scores.items()]
				_scores = [v for k, v in _scores.items()]

			# Dev scores improved (or tied): remember them, optionally save a
			# checkpoint and refresh the test-set numbers.
			if sum(_scores) >= sum(lowest_loss):
				lowest_loss = _scores
				if _C.save_model_each_epoch:
					dist_model.save_checkpoint(_A.serialization_dir, "model_epoch_%d" % (epoch + 1))

				if _C.enable_nlu:
					if _C.dev_path == _C.test_path:
						best_test_performance = _scores
					else:
						if _C.running_task != 'record':
							best_test_performance = gen_for_nlu(_C, test_loader, model, device, tokenizer, output_path=_A.output_path, is_root=is_root)
						else:
							# Evaluate on the test split; this branch used to
							# pass `dev_loader`, reporting dev numbers as test.
							best_test_performance = gen_for_nlu_record_eval(_C, test_loader, model, device, output_path=_A.output_path, is_root=is_root)
						best_test_performance = [v for k, v in best_test_performance.items()]

			# Report on the root rank only, like the other summary prints.
			if _C.enable_nlu and is_root:
				test_performance_str = "Best Test "
				for i, v in enumerate(best_test_performance):
					test_performance_str += '{} {:.2f}, '.format(_keys[i], v)
				print(test_performance_str)

			
				