import torch
import numpy as np

from transformers import AutoTokenizer, AutoConfig
import os
import hjson
from argparse import ArgumentParser
from torch.utils.tensorboard import SummaryWriter
import deepspeed
from liger_kernel.transformers import AutoLigerKernelForCausalLM
import torch.distributed as dist
from peft import LoraConfig, get_peft_model


from cot_datasets import CotDataset, collate_fn
from model import SynAdapt, set_seed
from utils import get_lora_param_groups_cpuOpti, get_lora_param_groups, wrap_distributed_model



if __name__ == '__main__':
	def str2bool(v):
		"""Convert a command-line value to a bool.

		Only the exact spellings 'True', 'true' and 'TRUE' map to True;
		every other value maps to False.
		"""
		return v in ('True', 'true', 'TRUE')

	# Seed via the project's set_seed helper before anything else runs.
	set_seed(42)

	# Command-line interface. Boolean flags go through str2bool, so they are
	# passed as strings on the command line (e.g. --use_scheduler True).
	parser = ArgumentParser()
	parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')

	parser.add_argument("--Qwen_model_path", type=str, default=None,
						help="path or repo id handed to from_pretrained for the config, tokenizer and base model")

	parser.add_argument("--log_dir", type=str, default="", help="parent directory for the TensorBoard run directory")
	parser.add_argument("--save_dir", type=str, default="", help="parent directory for checkpoints")
	parser.add_argument("--time_str", type=str, default="", help="run identifier used to name the log and checkpoint sub-directories")
	# The logging setup reads args.reload_writer_dir, so the flag has to be
	# registered; the empty default selects a fresh runs_<time_str> directory.
	parser.add_argument("--reload_writer_dir", type=str, default="",
						help="existing TensorBoard directory to keep writing into; empty creates runs_<time_str> under --log_dir")

	parser.add_argument("--epoch", type=int, default=3, help="number of training epochs")
	parser.add_argument("--save_step", type=int, default=512, help="save a checkpoint every this many batches")
	parser.add_argument("--micro_batch_size", type=int, default=1, help="value written to train_micro_batch_size_per_gpu in the DeepSpeed config")
	parser.add_argument("--n_accumulation_steps", type=int, default=1, help="value written to gradient_accumulation_steps in the DeepSpeed config")

	parser.add_argument('--activation_ckpt', type=str2bool, default="False", help='activation checkpointing switch (True/False)')
	parser.add_argument("--use_optimizer", type=str2bool, default="False", help="keep the optimizer section of the DeepSpeed config (True/False)")
	parser.add_argument("--use_scheduler", type=str2bool, default="True", help="keep the scheduler section of the DeepSpeed config (True/False)")

	#! ccot related
	parser.add_argument("--cot_cnt", type=int, default=512, help="cot_length passed to CotDataset")
	parser.add_argument("--time_step", type=int, default=4, help="time_step passed to CotDataset and to the model forward call")
	parser.add_argument("--raw_train_path", type=str, default=None, help="raw training data path passed to CotDataset")
	parser.add_argument("--index_train_path", type=str, default=None, help="training index path passed to CotDataset")
	parser.add_argument("--target_ccot_dir", type=str, default=None, help="target ccot directory passed to CotDataset")

	# Adds DeepSpeed's own flags (args.deepspeed_config is read afterwards).
	parser = deepspeed.add_config_arguments(parser)
	args = parser.parse_args()

	#########################################
	# Prepare parameter
	#########################################
	# Without a config path open() would fail with an opaque TypeError, so
	# report it as a normal usage error instead.
	if not args.deepspeed_config:
		parser.error("--deepspeed_config is required")
	with open(args.deepspeed_config, 'r') as config_file:
		ds_config = hjson.load(config_file)

	# Command-line values take precedence over what the config file says.
	ds_config['gradient_accumulation_steps'] = args.n_accumulation_steps
	ds_config['train_micro_batch_size_per_gpu'] = args.micro_batch_size
	# pop(..., None) instead of del: the section may be absent from the file.
	if not args.use_optimizer:
		ds_config.pop('optimizer', None)
	if not args.use_scheduler:
		ds_config.pop('scheduler', None)

	save_dir = os.path.join(args.save_dir, f"{args.time_str}")
	if args.local_rank == 0:
		# getattr keeps this working when the parser defines no
		# --reload_writer_dir flag; empty/None selects a fresh run directory.
		writer_dir = getattr(args, "reload_writer_dir", "") or os.path.join(args.log_dir, f'runs_{args.time_str}')
		writer = SummaryWriter(writer_dir)
		print(f"[Log] into {writer_dir}")
		# Only local rank 0 creates the checkpoint directory.
		os.makedirs(save_dir, exist_ok=True)


	#########################################
	# Prepare model and dataset
	#########################################
	# prepare model
	# Config, tokenizer and base model all come from the same path; the
	# attention implementation is forced on the config before the model loads.
	qwen_config = AutoConfig.from_pretrained(args.Qwen_model_path, trust_remote_code=True)
	qwen_config._attn_implementation = "flash_attention_2"
	qwen_tokenizer = AutoTokenizer.from_pretrained(args.Qwen_model_path, trust_remote_code=True)
	qwen_model = AutoLigerKernelForCausalLM.from_pretrained(
		args.Qwen_model_path,
		trust_remote_code=True,
		torch_dtype=torch.bfloat16,
		config=qwen_config,
	)

	# Attach a LoRA adapter named "lora_ccot" targeting q_proj and v_proj.
	lora_config = LoraConfig(
		task_type="CAUSAL_LM",
		target_modules=["q_proj", "v_proj"],
		r=8,
		lora_alpha=32,
		lora_dropout=0.1,
		bias="none",
	)
	qwen_lora_model = get_peft_model(qwen_model, lora_config, adapter_name="lora_ccot")
	synAdapt = SynAdapt(qwen_lora_model, qwen_tokenizer)

	# Mark every parameter whose name mentions lm_head or embed_tokens as
	# trainable and count the matches.
	target_cnt = 0
	for param_name, weight in synAdapt.qwen_lora_model.named_parameters():
		if "lm_head" in param_name or "embed_tokens" in param_name:
			weight.requires_grad = True
			target_cnt += 1
	# Exactly two matching parameters are expected; anything else aborts.
	if target_cnt != 2:
		raise Exception("Embedding name is not correct!")
	print(">>> [Embedding Trainable] Setting embedding to be trainable sucessfully!!!")
	synAdapt.qwen_lora_model.print_trainable_parameters()
	synAdapt.to(torch.bfloat16)
	
	# prepare dataset
	train_dataset = CotDataset(
		qwen_tokenizer,
		args.index_train_path,
		args.raw_train_path,
		args.target_ccot_dir,
		cot_length=args.cot_cnt,
		time_step=args.time_step,
	)
	print(f"Train dataset size: {len(train_dataset)}")

	# Hand the model, dataset and DeepSpeed config to the project helper,
	# which returns the loader, engine, optimizer and scheduler.
	# The collate function is bound to the tokenizer through a lambda.
	ddp_train_loader, ddp_engine, ddp_optimizer, ddp_scheduler = wrap_distributed_model(
		args,
		train_dataset=train_dataset,
		collate_fn=lambda samples: collate_fn(samples, qwen_tokenizer),
		model=synAdapt,
		ds_config=ds_config,
	)
	device = ddp_engine.device
	# Report the learning rate and parameter count of each optimizer group.
	for group_idx, param_group in enumerate(ddp_optimizer.param_groups):
		group_lr = param_group['lr']
		n_group_params = len(param_group['params'])
		print(f"Group {group_idx}: LR = {group_lr}, Params = {n_group_params}")


	#########################################
	# Begin Training
	#########################################
	# Loop-invariant values used for logging and checkpoint tags.
	n_batches = len(ddp_train_loader)
	is_main = args.local_rank == 0

	# The flag is registered as --epoch, so the parsed attribute is
	# args.epoch; args.epochs does not exist and raises AttributeError.
	for epoch_id in range(args.epoch):
		if is_main:
			print("===============================", flush=True)
			print(f"Epoch {epoch_id}")
			print("===============================")
		# Running per-batch losses; cleared after every checkpoint.
		loss_align_list = []
		loss_ans_ce_list = []
		loss_final_l1_list = []

		input_length_list = []
		for idx, batch in enumerate(ddp_train_loader):
			inputs, text_inputs = batch

			if is_main:
				# Length statistics are only printed (and reset) on rank 0, so
				# only rank 0 collects them; other ranks would grow the list
				# for the whole epoch without ever using it.
				input_length_list.append(text_inputs[-1])
				if idx % 16 == 0:
					print("*******************************", flush=True)
					print(f"[Batch Q] {text_inputs[0][0]}")
					print(f"[Batch A] {text_inputs[1][0]}")
					print("*******************************")
					print(f"Max length: {np.max(input_length_list)}, Avg length: {np.mean(input_length_list)}", flush=True)
					input_length_list = []

			# Move the tensors this step needs onto the engine's device.
			# Positions 2 and 3 of `inputs` are not used here.
			dev = ddp_engine.device
			question_ids = inputs[0].to(dev)
			question_masks = inputs[1].to(dev)
			target_dcot_ids = inputs[4].to(dev)
			target_dcot_mask = inputs[5].to(dev)
			target_ccot_list = [[tensor.to(dev) for tensor in group] for group in inputs[6]]
			cot_fill_ids = inputs[7].to(dev)
			cot_fill_mask = inputs[8].to(dev)
			end_think_ids = inputs[9].to(dev)
			end_think_mask = inputs[10].to(dev)
			answer_ids = inputs[11].to(dev)
			answer_mask = inputs[12].to(dev)
			torch.cuda.empty_cache()

			loss_align, loss_ans_ce, loss_final_l1 = ddp_engine.forward(
				question_ids, question_masks,
				answer_ids, answer_mask,
				target_dcot_ids, target_dcot_mask, target_ccot_list,
				cot_fill_ids, cot_fill_mask,
				end_think_ids, end_think_mask,
				time_step=args.time_step,
			)
			# The three losses are summed with equal weight.
			ddp_engine.backward(loss_align + loss_ans_ce + loss_final_l1)
			ddp_engine.step()
			torch.cuda.empty_cache()

			# training log
			loss_align_list.append(loss_align.item())
			loss_ans_ce_list.append(loss_ans_ce.item())
			loss_final_l1_list.append(loss_final_l1.item())
			torch.cuda.empty_cache()
			if is_main and idx % args.n_accumulation_steps == 0:
				lrs = list(ddp_engine.get_lr())
				print(f"[Batch-{idx}/{n_batches}]--Training Learning Rate: {lrs}", flush=True)
				print(f"[Batch-{idx}/{n_batches}]--Training Align Loss: {sum(loss_align_list) / len(loss_align_list)}", flush=True)
				print(f"[Batch-{idx}/{n_batches}]--Training Ans Loss: {sum(loss_ans_ce_list) / len(loss_ans_ce_list)}", flush=True)
				print(f"[Batch-{idx}/{n_batches}]--Training Final L1 Loss: {sum(loss_final_l1_list) / len(loss_final_l1_list)}", flush=True)

			# Checkpoint every save_step batches (never at batch 0) and always
			# on the last batch of the epoch.
			is_last_batch = idx == n_batches - 1
			if (idx % args.save_step == 0 and idx != 0) or is_last_batch:
				if is_main:
					print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
					print(f"<<Save Step Log...>>")
					print(f"[Batch-{idx}/{n_batches}]--AVG Training Step L1 Loss: {sum(loss_align_list) / len(loss_align_list)}", flush=True)
					print(f"[Batch-{idx}/{n_batches}]--AVG Training Ans Loss: {sum(loss_ans_ce_list) / len(loss_ans_ce_list)}", flush=True)
					print(f"[Batch-{idx}/{n_batches}]--AVG Training Final L1 Loss: {sum(loss_final_l1_list) / len(loss_final_l1_list)}", flush=True)
					print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
				# Start a fresh averaging window after each checkpoint.
				loss_align_list = []
				loss_ans_ce_list = []
				loss_final_l1_list = []

				# The tag is the global batch index across epochs.
				# save_checkpoint runs on every rank, not only rank 0.
				tag = f"checkpoint_{epoch_id * n_batches + idx}"
				ddp_engine.save_checkpoint(save_dir, tag=tag)
				if is_main:
					print(f"Save checkpoint into {save_dir}, Tag: {tag}", flush=True)
				torch.cuda.empty_cache()
				# Keep all ranks in step before training continues.
				dist.barrier()
			
