import torch
import numpy as np
from transformers import AutoTokenizer, AutoConfig
import os
import hjson
from argparse import ArgumentParser
import deepspeed
from liger_kernel.transformers import AutoLigerKernelForCausalLM
import torch.distributed as dist
from peft import LoraConfig, get_peft_model
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint


from model import SynAdapt, set_seed
from cot_datasets import CotDataset_hardness, collate_fn_hardness
from utils import get_lora_param_groups_cpuOpti, get_lora_param_groups, wrap_distributed_model, print_trainable_parameters


if __name__ == '__main__':

	def str2bool(v):
		"""Convert a command-line token to a bool.

		Only the exact spellings 'True', 'true' and 'TRUE' yield True; every
		other value (including other capitalisations) yields False.
		"""
		return v in ('True', 'true', 'TRUE')

	# Fixed seed for the run (set_seed is the project helper imported from model).
	set_seed(42)

	parser = ArgumentParser()
	parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')

	# Base model and the checkpoint restored on top of it
	parser.add_argument("--Qwen_model_path", type=str, default=None, help="path or repo id of the base Qwen model and tokenizer")
	parser.add_argument("--LLM_state_dict_dir", type=str, default=None, help="DeepSpeed ZeRO checkpoint directory to restore weights from")
	parser.add_argument("--LLM_state_dict_tag", type=str, default=None, help="checkpoint tag inside --LLM_state_dict_dir")

	parser.add_argument("--binary_path", type=str, default=None, help="path of the training data file passed to the dataset")

	# Training schedule
	parser.add_argument("--epoch", type=int, default=3, help="number of training epochs")
	parser.add_argument("--save_step", type=int, default=512, help="save a checkpoint every this many batches")
	parser.add_argument("--time_str", type=str, default="", help="run identifier used as sub-directory of --save_dir")
	parser.add_argument("--micro_batch_size", type=int, default=1, help="micro batch size per GPU")
	parser.add_argument("--n_accumulation_steps", type=int, default=1, help="gradient accumulation steps")
	parser.add_argument("--lr", type=float, default=5e-4, help="learning rate written into the DeepSpeed optimizer section")
	parser.add_argument("--use_optimizer", type=str2bool, default="False", help="keep the optimizer section of the DeepSpeed config")
	parser.add_argument("--use_scheduler", type=str2bool, default="True", help="keep the scheduler section of the DeepSpeed config")

	parser.add_argument("--save_dir", type=str, default="", help="root directory for checkpoints")

	# CCoT related
	parser.add_argument("--cot_length", type=int, default=512, help="cot_length passed to the dataset")
	parser.add_argument("--time_step", type=int, default=4, help="time_step passed to the model forward")

	# Adds --deepspeed / --deepspeed_config to the parser.
	parser = deepspeed.add_config_arguments(parser)

	args = parser.parse_args()
	if not args.deepspeed_config:
		# Without this check open(None) fails with an unhelpful TypeError.
		parser.error("--deepspeed_config is required")
	with open(args.deepspeed_config, 'r') as file:
		ds_config = hjson.load(file)

	# Command-line values override whatever the config file carries.
	ds_config['gradient_accumulation_steps'] = args.n_accumulation_steps
	ds_config['train_micro_batch_size_per_gpu'] = args.micro_batch_size
	if args.use_optimizer:
		ds_config["optimizer"].setdefault("params", {})["lr"] = args.lr
	else:
		# Only touch the optimizer section when it is kept; previously the lr
		# was written unconditionally, which raised KeyError for configs that
		# have no optimizer section at all.
		ds_config.pop('optimizer', None)
	if not args.use_scheduler:
		ds_config.pop('scheduler', None)

	save_dir = os.path.join(args.save_dir, f"{args.time_str}")
	if args.local_rank == 0:
		# exist_ok avoids the check-then-create race between processes.
		os.makedirs(save_dir, exist_ok=True)

	#########################################
	# Prepare model
	#########################################
	# Load config, tokenizer and bf16 weights of the base model from the same path.
	qwen_config = AutoConfig.from_pretrained(args.Qwen_model_path, trust_remote_code=True)
	qwen_config._attn_implementation = "flash_attention_2"
	qwen_tokenizer = AutoTokenizer.from_pretrained(args.Qwen_model_path, trust_remote_code=True)
	qwen_model = AutoLigerKernelForCausalLM.from_pretrained(
		args.Qwen_model_path,
		trust_remote_code=True,
		torch_dtype=torch.bfloat16,
		config=qwen_config,
	)

	# Attach a LoRA adapter named "lora_ccot" to the q/v projections.
	lora_config = LoraConfig(
		r=8,
		lora_alpha=32,
		target_modules=["q_proj", "v_proj"],
		lora_dropout=0.1,
		bias="none",
		task_type="CAUSAL_LM",
	)
	qwen_lora_model = get_peft_model(qwen_model, lora_config, adapter_name="lora_ccot")
	synAdapt = SynAdapt(qwen_lora_model, qwen_tokenizer)

	# Restore weights from the ZeRO checkpoint. The flags are declared as
	# --LLM_state_dict_dir / --LLM_state_dict_tag, so those are the attribute
	# names on args (args.state_dict_dir / args.state_dict_tag do not exist
	# and raised AttributeError).
	combined_model = get_fp32_state_dict_from_zero_checkpoint(
		args.LLM_state_dict_dir,
		tag=args.LLM_state_dict_tag,
	)
	# strict=False tolerates keys missing from the checkpoint, but any key the
	# model does not know about aborts the run.
	missing_keys, unexpected_keys = synAdapt.load_state_dict(combined_model, strict=False)
	if unexpected_keys:
		raise Exception("⚠️ 忽略无关参数:", unexpected_keys)

	# freeze LLM: every parameter under qwen_lora_model stops receiving gradients
	for param in synAdapt.qwen_lora_model.parameters():
		param.requires_grad = False
	print_trainable_parameters(synAdapt)
	
	#########################################
	# Prepare dataset
	#########################################
	# Build the training set from args.binary_path using the Qwen tokenizer.
	train_dataset = CotDataset_hardness(
		qwen_tokenizer,
		args.binary_path,
		cot_length=args.cot_length,
	)
	print(f"Train dataset size: {len(train_dataset)}")

	# Wrap model and dataset for distributed training; the project helper
	# returns the loader, engine, optimizer and scheduler in that order.
	distributed_parts = wrap_distributed_model(
		args,
		train_dataset=train_dataset,
		collate_fn=lambda samples: collate_fn_hardness(samples, qwen_tokenizer),
		model=synAdapt,
		ds_config=ds_config,
	)
	ddp_train_loader, ddp_engine, ddp_optimizer, ddp_scheduler = distributed_parts
	device = ddp_engine.device

	# Report learning rate and parameter count of every optimizer group.
	for group_idx, param_group in enumerate(ddp_optimizer.param_groups):
		group_lr = param_group['lr']
		n_params = len(param_group['params'])
		print(f"Group {group_idx}: LR = {group_lr}, Params = {n_params}")

	# Main training loop: one pass over the distributed loader per epoch.
	# The flag is declared as --epoch, so the value lives on args.epoch
	# (args.epochs does not exist and raised AttributeError).
	num_batches = len(ddp_train_loader)
	target_device = ddp_engine.device
	for epoch_id in range(args.epoch):
		if args.local_rank == 0:
			print("===============================", flush=True)
			print(f"Epoch {epoch_id}")
			print("===============================")
		# Losses accumulated since the last checkpoint (reset after each save).
		loss_list = []
		for idx, batch in enumerate(ddp_train_loader):
			inputs, text_inputs = batch
			if args.local_rank == 0 and idx % 16 == 0:
				print("*******************************", flush=True)
				print(f"[Batch win Q] {text_inputs[0][0]}")
				print(f"[Batch lose Q] {text_inputs[1][0]}")
				print("*******************************")

			# CPU-heavy preprocessing should be done ahead of time (inside the
			# dataloader) or offline, not here.
			(
				win_prompt_ids,
				win_prompt_masks,
				lose_prompt_ids,
				lose_prompt_masks,
				cot_fill_ids,
				cot_fill_masks,
				end_think_ids,
				end_think_mask,
			) = [tensor.to(target_device) for tensor in inputs[:8]]
			torch.cuda.empty_cache()

			loss = ddp_engine.forward_diffclf(
				win_prompt_ids, win_prompt_masks, lose_prompt_ids, lose_prompt_masks,
				cot_fill_ids, cot_fill_masks, end_think_ids, end_think_mask,
				time_step=args.time_step,
			)
			ddp_engine.backward(loss)
			ddp_engine.step()

			loss_list.append(loss.item())
			torch.cuda.empty_cache()
			if args.local_rank == 0 and idx % args.n_accumulation_steps == 0:
				lrs = list(ddp_engine.get_lr())
				print(f"[Batch-{idx}/{num_batches}]--Training Learning Rate: {lrs}", flush=True)
				print(f"[Batch-{idx}/{num_batches}]--Training Diff Loss: {sum(loss_list) / len(loss_list)}", flush=True)

			# Checkpoint every save_step batches (not at batch 0) and at the
			# last batch of each epoch.
			is_save_step = idx % args.save_step == 0 and idx != 0
			is_last_batch = idx == num_batches - 1
			if is_save_step or is_last_batch:
				if args.local_rank == 0:
					print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
					print("<<Save Step Log...>>")
					print(f"[Batch-{idx}/{num_batches}]--AVG Training Diff Loss: {sum(loss_list) / len(loss_list)}", flush=True)
					print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
				loss_list = []
				torch.cuda.empty_cache()

				# Tag is the global batch index across epochs; every rank
				# calls save_checkpoint, only rank 0 logs it.
				tag = f"checkpoint_{epoch_id * num_batches + idx}"
				ddp_engine.save_checkpoint(save_dir, tag=tag)
				if args.local_rank == 0:
					print(f"Save checkpoint into {save_dir}, Tag: {tag}", flush=True)
				torch.cuda.empty_cache()
				# Keep ranks in step before training resumes.
				dist.barrier()
			
