import argparse
import torch
import torch.nn as nn
from config import Config
import os, sys, math
from pre_train_dataset import PreTrainDataset, get_data_loader
from MetaPreTrainingDataset import get_h5py_nlp_dataset, get_h5py_nlp_data, _data_wrapper, TASK_NAME_LIST
import t5_model
import numpy as np
import random
import utils
from checkpointing import CheckpointManager
from tqdm import tqdm
import deepspeed
from tqdm.contrib.logging import logging_redirect_tqdm

def _average_all(tensor):
	"""Return a detached copy of `tensor` averaged over all distributed ranks.

	The input tensor is never modified. When `torch.distributed` is unavailable
	or no process group has been initialized (e.g. a single-process run), the
	world size is treated as 1, so the values come back unchanged instead of
	raising.
	"""
	# We copy because `all_reduce` modifies its argument in-place
	averaged = tensor.detach().clone()
	world_size = 1
	if torch.distributed.is_available() and torch.distributed.is_initialized():
		# We use `all_reduce` because it is better supported than `reduce`
		torch.distributed.all_reduce(averaged, torch.distributed.ReduceOp.SUM)
		world_size = torch.distributed.get_world_size()
	return averaged / world_size

def evaluation(_C, eval_data, model, device, is_root=True):
	"""Compute the mean loss of `model` over `eval_data` and return it negated.

	Args:
		_C: Run configuration; kept for interface compatibility, not read here.
		eval_data: Iterable of batch dicts. Every value except those stored under
			'gt_x', 'gt_y' and 'data_index' is moved to `device` with `.to()`
			(the batch dict is updated in place).
		model: Callable accepting the keyword arguments used below and returning
			an object with a `.loss` attribute. It is switched to eval mode and
			left in that mode; callers re-enable training mode themselves.
		device: Target device for the batch values.
		is_root: When True, show a tqdm progress bar and print the final loss.

	Returns:
		The negated mean of the per-batch losses (each one averaged across ranks
		by `_average_all`), so that a larger value means a better model.

	Raises:
		ValueError: If `eval_data` yields no batches.
	"""
	model.eval()
	# Entries that are not moved to `device`.
	skipped_keys = ('gt_x', 'gt_y', 'data_index')
	loss_list = []

	with logging_redirect_tqdm():
		with torch.no_grad():
			# Only the root rank draws a progress bar; the others iterate silently.
			pbar = tqdm(eval_data) if is_root else eval_data

			for batch in pbar:
				for n in batch:
					if n in skipped_keys:
						continue
					batch[n] = batch[n].to(device)

				outputs = model(
					input_ids=batch['encoder_input_ids'],
					attention_mask=batch['encoder_mask'],
					labels=batch['decoder_input_ids'],
					task_ids=batch['task_ids'],
					task_type_ids=batch['task_type_ids'],
					prefix_ids=batch['prefix_ids'],
				)
				# Every rank takes part in the reduction, so all ranks see the same value.
				loss_list.append(_average_all(outputs.loss).item())

	if not loss_list:
		# Without this guard the mean below fails with a bare ZeroDivisionError.
		raise ValueError("evaluation received no batches; cannot compute a loss")

	final_loss = sum(loss_list) / len(loss_list)

	if is_root:
		print("EVAL LOSS %.2f" % final_loss)

	return -1 * final_loss

# Command-line interface of the training script.
# The text is passed as `description`: the first positional argument of
# ArgumentParser is `prog`, which would replace the program name in the usage line.
parser = argparse.ArgumentParser(description="Train a MT5 for Machine Translation")
parser.add_argument(
    "--config", required=True, help="Path to a config file with all configuration parameters."
)
parser.add_argument(
    "--config-override",
    default=[],
    nargs="*",
    help="A sequence of key-value pairs specifying certain config arguments (with dict-like "
    "nesting) using a dot operator. The actual config will be updated and recorded in "
    "the serialization directory.",
)
parser.add_argument(
    "--serialization-dir",
    default=None,
    help="Path to a (non-existent) directory for serializing checkpoints and tensorboard logs.",
)
parser.add_argument(
    "--start-from-checkpoint",
    default=None,
    help="Path to load checkpoint and continue training [only supported for module_training].",
)
parser.add_argument(
    "--output-path",
    default=None,
    help="Path to save output captions",
)
parser.add_argument('--local_rank', type=int, default=-1,
                    help='local rank passed from distributed launcher')
# --train and --validation cannot be combined; passing neither is also accepted.
group = parser.add_mutually_exclusive_group()
group.add_argument('--train', action='store_true')
group.add_argument('--validation', action='store_true')
# Let DeepSpeed register its own flags (the script later reads `_A.deepspeed`).
parser = deepspeed.add_config_arguments(parser)

if __name__ == "__main__":
	# Script entry point: parse arguments, build model and data, wrap everything
	# in a DeepSpeed engine, then optionally resume, validate and/or train.
	_A = parser.parse_args()
	# Fail fast: the training branch writes the config and checkpoints into this
	# directory, and `os.makedirs(None)` would only fail after the model is built.
	if _A.train and _A.serialization_dir is None:
		parser.error("--serialization-dir is required when --train is set")
	_C = Config(_A.config, _A.config_override)

	# Seed every RNG in use and ask cuDNN for deterministic kernels.
	np.random.seed(_C.random_seed)
	random.seed(_C.random_seed)
	torch.manual_seed(_C.random_seed)
	torch.cuda.manual_seed_all(_C.random_seed)
	torch.backends.cudnn.benchmark = False
	torch.backends.cudnn.deterministic = True

	os.environ["TOKENIZERS_PARALLELISM"] = "false"
	os.environ["NCCL_DEBUG"] = "WARN"
	local_rank = _A.local_rank
	torch.cuda.set_device(local_rank)
	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

	# One task embedding per known pre-training task, or none at all.
	_C.task_embed_count = len(TASK_NAME_LIST) if _C.enable_pretrain_task_embeddings else 0

	if _A.deepspeed:
		deepspeed.init_distributed()

	# Without --deepspeed the single process is treated as the root.
	is_root = (not _A.deepspeed) or torch.distributed.get_rank() == 0

	if _C.enable_full_finetune:
		tokenizer, model = t5_model.get_full_finetune_t5_model(_C)
	else:
		tokenizer, model = t5_model.get_t5_model(_C)

	if is_root:
		total_parameter_count = sum(p.numel() for p in model.parameters())
		trainable_parameter_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
		print('Total Parameter Count %d' % total_parameter_count)
		print('Trainable Parameter Count %d' % trainable_parameter_count)

		print(_C)
		for arg in vars(_A):
			print("{:<20}: {}".format(arg, getattr(_A, arg)))

	# Validation data.
	if _C.enable_meta_training:
		if _A.deepspeed:
			# Per-process micro-batch size derived from the global batch size.
			val_batch_size = _C.batch_size // (torch.distributed.get_world_size() * _C.gradient_accumulation_steps)
		else:
			val_batch_size = 1
		dev_loader = get_h5py_nlp_data(_C, "validation", val_batch_size, tokenizer, _C.max_length, shuffle=False, distributed=_A.deepspeed, is_root=is_root, is_train=False)
	else:
		dev_data = PreTrainDataset(_C, tokenizer, _C.dev_path)
		dev_loader = get_data_loader(_C, dev_data, _C.batch_size, shuffle=False)

	# Training data. Note that `train_loader` is rebound below to the loader
	# returned by `deepspeed.initialize`.
	if _C.enable_meta_training:
		train_data = get_h5py_nlp_dataset(_C, "train", tokenizer, _C.max_length, is_root=is_root, is_train=True)
	else:
		train_data = PreTrainDataset(_C, tokenizer, _C.train_path)
		train_loader = get_data_loader(_C, train_data, _C.batch_size, shuffle=True)

	ds_config = {
		"train_batch_size": _C.batch_size,
		"gradient_accumulation_steps": _C.gradient_accumulation_steps,
		"steps_per_print": 100 * _C.gradient_accumulation_steps,
		"fp16": {
			"enabled": False
		},
	}

	# Optimizer choice: explicit Adam helper, DeepSpeed-configured Adam for full
	# fine-tuning, or the T5-specific helper otherwise.
	if _C.enable_adam_opt:
		optimizer = utils.build_optimizer(_C, model)
	elif _C.enable_full_finetune:
		ds_config['optimizer'] = {
			"type": "Adam",
			"params": {
				"lr": 5e-5
			}
		}
		optimizer = None
	else:
		optimizer = utils.build_t5_optimizer(_C, model)

	dist_model, _, train_loader, _ = deepspeed.initialize(args=_A, model=model, model_parameters=[p for p in model.parameters() if p.requires_grad], config=ds_config, training_data=train_data, collate_fn=_data_wrapper, optimizer=optimizer)

	if _A.start_from_checkpoint is not None:
		dist_model.load_checkpoint(_A.start_from_checkpoint, load_module_only=_C.only_load_module)
		evaluation(_C, dev_loader, dist_model, device, is_root=is_root)

	if _C.enable_new_task_embeddings:
		dist_model.module.update_task_embedding(1)

	if _A.validation:
		# NOTE(review): `gen_data` is neither defined nor imported in the visible
		# part of this file; unless it is provided elsewhere this branch raises
		# NameError. Confirm where it is supposed to come from.
		gen_data(_C, dev_loader, model, device, tokenizer, output_path=_A.output_path, is_root=is_root)

	if _A.train:
		train_iter = iter(train_loader)

		if _C.num_training_steps == 0:
			_C.num_training_steps = _C.max_epoch * len(train_iter) // (_C.batch_size // torch.distributed.get_world_size())
		# An "epoch" here is one block of `checkpoint_every_step` optimizer steps
		# followed by an evaluation, not a full pass over the data.
		epoch_num = math.ceil(_C.num_training_steps / _C.checkpoint_every_step)

		os.makedirs(_A.serialization_dir, exist_ok=True)
		_C.dump(os.path.join(_A.serialization_dir, "config.yml"))

		eval_every = _C.checkpoint_every_step
		total_step = 0  # optimizer steps completed (after gradient accumulation)
		raw_total_step = 0  # micro-batches processed
		# Best score so far; `evaluation` returns the negated loss, so higher is better.
		best_score = -1e10

		for epoch in range(epoch_num):
			# The last block may be shorter than `eval_every`.
			run_step = min(eval_every, _C.num_training_steps - total_step)
			dist_model.train()

			if is_root:
				print('EPOCH %d / %d' % (epoch + 1, epoch_num))
				pbar = tqdm(total=run_step, file=sys.stdout)

			loss_list = []

			with logging_redirect_tqdm():
				for _ in range(run_step * _C.gradient_accumulation_steps):
					try:
						batch = next(train_iter)
					except StopIteration:
						# Loader exhausted: start another pass over the data. Only
						# StopIteration is handled so real data errors and
						# KeyboardInterrupt are no longer swallowed.
						train_iter = iter(train_loader)
						batch = next(train_iter)

					for n in batch:
						if n in ['gt_x', 'gt_y', 'data_index']: continue
						batch[n] = batch[n].to(dist_model.local_rank)

					outputs = dist_model(
						input_ids=batch['encoder_input_ids'],
						attention_mask=batch['encoder_mask'],
						labels=batch['decoder_input_ids'],
						task_ids=batch['task_ids'],
						task_type_ids=batch['task_type_ids'],
						prefix_ids=batch['prefix_ids'],
					)
					loss = outputs.loss
					dist_model.backward(loss)
					dist_model.step()
					raw_total_step += 1

					loss_list.append(_average_all(loss).item())

					if raw_total_step % _C.gradient_accumulation_steps == 0:
						if is_root:
							ave_loss = sum(loss_list) / len(loss_list)
							pbar.set_description("loss %.2f" % ave_loss)
							pbar.update(1)
							pbar.refresh()
						# Step bookkeeping must happen on EVERY rank: `run_step` is
						# derived from `total_step`, and ranks that disagree on it
						# would issue a different number of collective calls in the
						# final block and leave the collectives mismatched.
						total_step += 1
						loss_list = []

			if is_root:
				pbar.close()

			_score = evaluation(_C, dev_loader, dist_model, device, is_root=is_root)
			if _score >= best_score:
				best_score = _score
				dist_model.save_checkpoint(_A.serialization_dir, "model_epoch_%d" % (epoch + 1))
			
				