import argparse
import copy
import json
import os
import pickle
import threading
import time
import traceback

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoConfig
from liger_kernel.transformers import AutoLigerKernelForCausalLM

from utils import prompt_mapping, ensure_add_token, get_precise_equidistant_points, split_list


def _register_position_hooks(layers, store, position):
	"""Attach a forward hook to every module in `layers` recording its output at token `position`.

	Each hook stores ``output[0][:, position, :]`` into `store`, keyed by the module, so after
	one forward pass `store` maps layer -> hidden state of that token.
	NOTE(review): assumes each layer returns a tuple whose first element is indexed
	(batch, seq, hidden) — this mirrors how the layer outputs were indexed originally.

	Returns:
		The list of hook handles; always pass it to `_remove_hooks` (use try/finally).
	"""
	def hook_fn(module, inputs, output):
		store[module] = output[0][:, position, :]
	return [layer.register_forward_hook(hook_fn) for layer in layers]


def _remove_hooks(handles):
	"""Detach every hook handle returned by `_register_position_hooks`."""
	for handle in handles:
		handle.remove()


def thread_task(thread_idx, args, LLM, tokenizer, device, all_data_list, output_path, output_dir, threshold=0.1,  verbose=False):
	"""Optimise a soft ("continuous") CoT embedding for every item in `all_data_list` on one device.

	For each item, the input embeddings of the string "<T>" * args.ccot_length become a
	trainable parameter placed between the prompt and "</think>". It is optimised with AdamW
	on the sum of:
	  * the cross-entropy loss of the answer tokens, and
	  * a smooth-L1 loss pulling every decoder layer's hidden state at the last "</think>"
	    token towards the hidden state obtained when the full textual CoT is fed instead.
	Optimisation stops after args.max_step steps or once the CE loss drops below `threshold`.

	Side effects: for each item, a pickle with the (subsampled) trajectory of the soft
	embedding is written to `output_dir/<Raw_data_idx>.pkl` and one JSON line is written to
	`output_path` (truncated at start). Per-item failures are printed with a traceback and
	the item is skipped.

	Args:
		thread_idx: worker index; only worker 0 prints progress.
		args: parsed CLI namespace (reads ccot_length, max_step, loop_step).
		LLM: causal LM exposing `model.layers`; moved to `device`. Only the soft CoT is
			handed to the optimiser — LLM weights are presumably frozen by the caller.
		tokenizer: tokenizer that must already know the "<T>" token.
		device: torch device this worker runs on.
		all_data_list: dicts with keys Question, Answer_Content, Split_COT_Content,
			Raw_data_idx (and GT_Answer when verbose).
		output_path: JSONL file summarising each processed item.
		output_dir: directory for the per-item embedding-trace pickles.
		threshold: CE-loss value that ends optimisation early.
		verbose: print per-step losses and a greedy generation after optimisation.

	Raises:
		ValueError: if args.max_step < 1 (no step would run and no loss would exist).
	"""
	total_steps = args.max_step
	if total_steps < 1:
		raise ValueError(f"max_step must be >= 1, got {total_steps}")

	print(f"========Thread {thread_idx} started. Num of data: {len(all_data_list)}")

	torch.cuda.set_device(device)
	LLM.to(device)
	embed = LLM.get_input_embeddings()
	layers_to_save = list(LLM.model.layers)

	cot_str = "<T>" * args.ccot_length
	soft_cot_ids = tokenizer(cot_str, return_tensors="pt", add_special_tokens=False).input_ids.to(LLM.device)

	begin_time = time.time()
	with open(output_path, 'w') as out_f:
		for data_idx, data_item in enumerate(all_data_list):
			try:
				if data_idx % 2 == 0 and thread_idx == 0 and data_idx != 0:
					print(f"Thread {thread_idx} processed {data_idx}/{len(all_data_list)} data items.")
					time_cost = time.time() - begin_time
					avg_time_cost = time_cost / data_idx
					print(f"Current Time Cost: {time_cost}(s), Avg time Cost: {avg_time_cost}(s)", flush=True)

				question = data_item["Question"]
				prompts = prompt_mapping['qwen_deepseek_distill'][0].format(input= question)
				answer = data_item['Answer_Content'] + '<｜end▁of▁sentence｜>'
				cot_content_endT = "".join(data_item['Split_COT_Content']) + '</think>'
				endThink_str = '</think>'
				raw_data_idx = data_item['Raw_data_idx']

				# The trainable soft CoT starts from the "<T>" token embeddings.
				soft_cot_embedding = embed(soft_cot_ids)
				soft_cot_embedding_parameters = nn.Parameter(soft_cot_embedding)
				optimizer = torch.optim.AdamW([soft_cot_embedding_parameters], lr=1e-3)

				prompts_ids = tokenizer(prompts, return_tensors="pt").input_ids.to(device)
				prompts_embeddings = embed(prompts_ids)
				answer_ids = tokenizer(answer, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
				answer_embeddings = embed(answer_ids)
				cot_content_endT_ids = tokenizer(cot_content_endT, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
				endT_ids = tokenizer(endThink_str, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
				entT_embeddings = embed(endT_ids)
				# Mask prompt, soft CoT and "</think>" with -100 (the label value HF causal-LM
				# losses ignore) so the CE loss covers only the answer tokens.
				Q_thinkCot_endThink_Ans_labels = torch.concat(
					(torch.full((prompts_ids.shape[0], prompts_ids.shape[1] + soft_cot_embedding.shape[1] + endT_ids.shape[1]), -100).to(device),
					answer_ids),
					dim=1
				)

				# Target: per-layer hidden state of the final token when the textual CoT is fed.
				# Hooks are removed in `finally` so a failing forward cannot leave them attached.
				last_token_features_dcot = {}
				handles = _register_position_hooks(layers_to_save, last_token_features_dcot, -1)
				try:
					with torch.no_grad():
						LLM.forward(
							input_ids = torch.cat([prompts_ids, cot_content_endT_ids], dim=1),
						)
				finally:
					_remove_hooks(handles)

				# In the soft-CoT sequence the matching token (last "</think>" token) sits
				# right before the answer tokens.
				pred_position = -answer_embeddings.shape[1] - 1
				ccot_trace_list = []
				for step in range(total_steps):
					last_token_features_ccot = {}
					handles = _register_position_hooks(layers_to_save, last_token_features_ccot, pred_position)
					try:
						outputs = LLM.forward(
							inputs_embeds = torch.cat([prompts_embeddings, soft_cot_embedding_parameters, entT_embeddings, answer_embeddings], dim=1),
							labels = Q_thinkCot_endThink_Ans_labels
						)
					finally:
						_remove_hooks(handles)

					ce_loss = outputs.loss
					# Mean over layers of the smooth-L1 distance between predicted and target states.
					l1_loss = torch.mean(torch.stack([
						F.smooth_l1_loss(last_token_features_ccot[layer], last_token_features_dcot[layer])
						for layer in last_token_features_ccot
					]))

					all_loss = ce_loss + l1_loss
					if verbose:
						print("===========================")
					all_loss.backward()
					optimizer.step()
					optimizer.zero_grad()
					if verbose:
						print(f"Step-{step}, Loss: {ce_loss.item()}, L1_loss: {l1_loss.item()}")
					ccot_trace_list.append(
						soft_cot_embedding_parameters.data.clone().detach().cpu()
					)

					if ce_loss.item() < threshold:
						break
					torch.cuda.empty_cache()

				if verbose:
					# Sanity check: greedy generation conditioned on the optimised soft CoT.
					gen_input_embeds = torch.cat([prompts_embeddings, soft_cot_embedding_parameters, entT_embeddings], dim=1)
					gen_outputs = LLM.generate(
						inputs_embeds = gen_input_embeds,
						do_sample=False, max_new_tokens=4096
					)
					generate_str = tokenizer.batch_decode(gen_outputs, skip_special_tokens=True)[0]
					print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
					print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
					print(f"<<GT Answer>>:\n {data_item['GT_Answer']}")
					print(f"<<Answer>>:\n {data_item['Answer_Content']}")
					print(f"<<Question Generate Str>>:\n {generate_str}")

				ccot_trace_list = get_precise_equidistant_points(ccot_trace_list, args.loop_step-1)
				embedding_path = os.path.join(output_dir, f"{raw_data_idx}.pkl")
				with open(embedding_path, "wb") as embed_out_f:
					pickle.dump(ccot_trace_list, embed_out_f)
				out_f.write(json.dumps(
					{
					"raw_data_idx": raw_data_idx,
					"trace_path": embedding_path,
					"final_ce_loss": ce_loss.item(),
					"final_l1_loss": l1_loss.item(),
					"trace_len": len(ccot_trace_list)
					}
				) + '\n')
				# Flush per item so completed records survive if the process is killed mid-run.
				out_f.flush()
			except Exception as e:
				print(f"Encounter Exception when process {data_idx}, Exception is {e}")
				traceback.print_exc()

	print(f"========Thread {thread_idx} end. [Output Embedding Trace Path]: {output_path}")



def _read_jsonl(path):
	"""Return the JSON objects stored one per line in `path`, skipping blank lines."""
	with open(path, 'r', encoding='utf-8') as file:
		return [json.loads(line) for line in file if line.strip()]


def _raw_index(item):
	"""Return an input item's raw index.

	Input data items store it under 'Raw_data_idx'; the lower-case 'raw_data_idx' (the key
	used in the per-thread output records) is accepted as a fallback.
	"""
	return item['Raw_data_idx'] if 'Raw_data_idx' in item else item['raw_data_idx']


def main(args):
	"""Load the dataset, drop already-processed items, and run one worker thread per GPU.

	Args:
		args: parsed CLI namespace (reads data_path, exist_index_path,
			output_embedding_trace_dir, model_path, n_thread, threshold, verbose and,
			through the workers, the optimisation settings).

	Side effects: creates `<output_embedding_trace_dir>/embedding_trace/`, loads the model
	and tokenizer once, and gives thread i a deep copy of the model plus device `cuda:i`.
	Thread i writes its summary to `<output_embedding_trace_dir>/Part_<i>.json`.
	"""
	all_data_list = _read_jsonl(args.data_path)

	# Resume support: skip items whose raw index already appears in an earlier output file
	# (those records store it under 'raw_data_idx').
	if args.exist_index_path != "":
		exist_raw_idx = {record['raw_data_idx'] for record in _read_jsonl(args.exist_index_path)}
		all_data_list = [item for item in all_data_list if _raw_index(item) not in exist_raw_idx]

	output_dir = args.output_embedding_trace_dir
	os.makedirs(output_dir, exist_ok=True)
	embedding_output_dir = os.path.join(output_dir, "embedding_trace")
	os.makedirs(embedding_output_dir, exist_ok=True)

	all_thread_data_list = split_list(all_data_list, args.n_thread)

	model_path = args.model_path
	config = AutoConfig.from_pretrained(model_path, trust_remote_code=True,)
	config._attn_implementation = "flash_attention_2"
	LLM = AutoLigerKernelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, config=config)
	tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
	# Freeze the LLM: only the per-item soft-CoT embeddings are optimised.
	for param in LLM.parameters():
		param.requires_grad = False
	ensure_add_token(tokenizer, LLM, "<T>")
	# Adding a token may resize the embedding matrix, creating fresh parameters that
	# require grad by default — freeze again so the whole model stays frozen.
	for param in LLM.parameters():
		param.requires_grad = False

	threads = []
	for i in range(args.n_thread):
		t = threading.Thread(
			target=thread_task,
			args=(i, args, copy.deepcopy(LLM), tokenizer, torch.device(f"cuda:{i}"),
				all_thread_data_list[i],
				os.path.join(output_dir, f"Part_{i}.json"), embedding_output_dir, args.threshold, args.verbose),
		)
		threads.append(t)
		t.start()

	for t in threads:
		t.join()
	print("All threads end up!!!")

def str2bool(v):
	"""Interpret a command-line string as a boolean.

	"true" in any letter case, "1", "yes" and "y" give True; every other value gives False.
	"""
	return str(v).strip().lower() in ("true", "1", "yes", "y")


if __name__ == '__main__':
	parser = argparse.ArgumentParser(description="Generate CCoT before fine-tuning")
	parser.add_argument('--verbose', type=str2bool, default='False', help="more detailed output")

	# Paths with no usable default are required, so a missing flag fails fast in argparse
	# instead of deep inside main (e.g. opening a file literally named "False").
	parser.add_argument('--model_path', type=str, required=True, help="LLM path")

	parser.add_argument('--data_path', type=str, required=True, help= "input data path")
	parser.add_argument('--exist_index_path', type=str, default='', help="existed CCoT index path")

	parser.add_argument('--output_embedding_trace_dir', type=str, required=True, help="output CCoT dir")

	# Maximum optimisation steps per item, and how many snapshots of the trajectory to keep.
	parser.add_argument('--max_step', type=int, default=128, help="total step num")
	parser.add_argument('--loop_step', type=int, default=8, help="save step num")
	# The soft CoT is initialised from "<T>" repeated this many times.
	parser.add_argument('--ccot_length', type=int, default=512)

	# One worker thread per GPU (thread i uses cuda:i); threshold is the CE loss for early stop.
	parser.add_argument('--n_thread', type=int, default=2)
	parser.add_argument('--threshold', type=float, default=0.1)

	args = parser.parse_args()
	main(args)