import sys
sys.path.append("../")
import os
import random
import argparse
import time
from datetime import datetime
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoConfig
from liger_kernel.transformers import AutoLigerKernelForCausalLM
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from peft import LoraConfig, get_peft_model
import pickle
from copy import deepcopy
import numpy as np

from CCoT_Finetune.main_model import SynAdapt
from utils import set_seed, load_jsonl, save_jsonl, construct_prompt, is_multi_choice
from parser import *
from trajectory import *
from data_loader import load_data
from evaluate import evaluate


@torch.no_grad()
def generate_completions(model, raw_model, tokenizer, prompts, reThink_prompts, device, 
					batch_size=1, steps=4, cot_gen_length=512, content_gen_length=12000, do_sample=False, disable_tqdm=False, log_step=4, thres=0.5):
	"""Generate one completion per prompt, routing each batch by a difficulty score.

	For every batch the difficulty score returned by ``model.judge_score`` is
	compared with ``thres``. Batches scoring above the threshold are answered by
	``raw_model.generate`` on the matching ``reThink_prompts``; all other batches
	are answered by ``model.generate`` on ``prompts``.

	Args:
		model: Object exposing ``judge_score(...)`` and ``generate(...)`` with the
			positional layout used below.
		raw_model: Object exposing ``generate(input_ids=..., attention_mask=...,
			max_new_tokens=..., do_sample=...)``.
		tokenizer: Tokenizer providing ``batch_encode_plus`` and ``batch_decode``.
		prompts: Prompt strings used for scoring and for the ``model`` path.
		reThink_prompts: Prompt strings, parallel to ``prompts``, used for the
			``raw_model`` path.
		device: Device the encoded tensors are moved to.
		batch_size: Number of prompts per batch.
		steps: Forwarded to ``model.generate`` as ``time_step``.
		cot_gen_length: Number of ``"<T>"`` markers in the filler string.
		content_gen_length: Forwarded as ``max_new_tokens`` on both paths.
		do_sample: Forwarded as ``do_sample`` on both paths.
		disable_tqdm: When true, no progress bar is created.
		log_step: The first decoded output of a batch is printed whenever the
			batch start offset is a multiple of this value.
		thres: Difficulty threshold; strictly greater scores count as hard.

	Returns:
		A tuple ``(generations, is_hard_list, hard_score_list)``. All three lists
		hold one entry per prompt; the flag and score of a batch are repeated for
		every prompt in that batch.
	"""
	end_think_mark = "</think>"
	cot_fill_str = "<T>" * cot_gen_length

	# `tqdm` is imported as the class itself (`from tqdm import tqdm`), so it is
	# called directly. `progress` stays None when the bar is disabled so the
	# loop below never touches an unbound name.
	progress = None if disable_tqdm else tqdm(total=len(prompts), desc="Generating Completions")

	generations = []
	is_hard_list = []
	hard_score_list = []
	for i in range(0, len(prompts), batch_size):
		batch_prompts = prompts[i:i+batch_size]
		batch_len = len(batch_prompts)
		prompt_inputs = tokenizer.batch_encode_plus(batch_prompts, padding=True, return_tensors='pt', padding_side="left", add_special_tokens=False)
		prompt_ids, prompt_masks = prompt_inputs['input_ids'].to(device), prompt_inputs['attention_mask'].to(device)

		batch_reThink_prompts = reThink_prompts[i:i+batch_size]
		reThink_prompt_inputs = tokenizer.batch_encode_plus(batch_reThink_prompts, padding=True, return_tensors='pt', padding_side="left", add_special_tokens=False)
		reThink_prompt_ids, reThink_prompt_masks = reThink_prompt_inputs['input_ids'].to(device), reThink_prompt_inputs['attention_mask'].to(device)

		# One filler string and one end-of-thinking marker per prompt in the batch.
		cot_fill_inputs = tokenizer.batch_encode_plus([cot_fill_str] * batch_len, padding=True, return_tensors='pt', padding_side="right", add_special_tokens=False)
		cot_fill_ids, cot_fill_mask = cot_fill_inputs['input_ids'].to(device), cot_fill_inputs['attention_mask'].to(device)

		end_think_inputs = tokenizer.batch_encode_plus([end_think_mark] * batch_len, padding=True, return_tensors='pt', padding_side="right", add_special_tokens=False)
		end_think_ids, end_think_mask = end_think_inputs['input_ids'].to(device), end_think_inputs['attention_mask'].to(device)

		# NOTE(review): scoring always uses time_step=4 while generation uses
		# `steps` -- confirm that the fixed value is intentional.
		difficulty_score = model.judge_score(prompt_ids, prompt_masks,
								cot_fill_ids, cot_fill_mask, end_think_ids, end_think_mask, time_step=4)

		# `.item()` only succeeds for a single-element tensor, so the score is
		# one value for the whole batch.
		difficulty = difficulty_score.detach().cpu().item()
		print(f"[Difficulty]: {difficulty}")
		is_hard = difficulty > thres

		# Repeat the batch-level verdict per prompt so these lists stay aligned
		# with `generations`.
		hard_score_list.extend([difficulty] * batch_len)
		is_hard_list.extend([is_hard] * batch_len)

		if is_hard:
			print(f">>>>> Hard Question Find!!!")
			outputs = raw_model.generate(
				input_ids = reThink_prompt_ids, attention_mask = reThink_prompt_masks,
				max_new_tokens = content_gen_length,
				do_sample=do_sample,
			)
		else:
			outputs = model.generate(
				prompt_ids, prompt_masks,
				cot_fill_ids, cot_fill_mask,
				end_think_ids, end_think_mask,
				time_step = steps,
				max_new_tokens = content_gen_length,
				do_sample=do_sample
			)
		output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

		if i % log_step == 0:
			print("============================")
			print(output_text[0], flush = True)
			print("============================")
		generations += output_text

		if progress is not None:
			# Advance by the number of prompts actually handled in this batch.
			progress.update(batch_len)

	if progress is not None:
		progress.close()
	return generations, is_hard_list, hard_score_list


def get_output_info(args, data_name):
	"""Build the output-file prefix and the full results path for a dataset.

	Returns a tuple ``(out_prefix, out_file)``. The prefix encodes the split,
	checkpoint tag, sample count, seed and max generation length; the file path
	additionally encodes the sampling flag, sample count per prompt, time step
	and CoT length, and lives under ``<output_dir>/<data_name>/``.
	"""
	prefix_fields = (
		args.split,
		args.state_dict_tag,
		f"dataCnt{args.num_test_sample}",
		f"seed{args.seed}",
		f"MaxG{args.max_tokens_per_call}",
	)
	out_prefix = "_".join(str(field) for field in prefix_fields)

	run_suffix = "_".join((
		f"sample{args.do_sample}",
		f"genCnt{args.n_sampling}",
		f"Tstep{args.time_step}",
		f"ccot{args.cot_len}",
	))
	out_file = f"{args.output_dir}/{data_name}/{out_prefix}_{run_suffix}.jsonl"
	return out_prefix, out_file


def prepare_data(data_name, args):
	"""Load a dataset and split it into pending examples and already-processed samples.

	Args:
		data_name: Dataset name passed to ``load_data`` and used as the output
			sub-directory.
		args: Parsed options; reads ``split``, ``data_dir``, ``num_test_sample``,
			``shuffle``, ``start``, ``end``, ``output_dir`` and ``overwrite``.

	Returns:
		A tuple ``(examples, processed_samples, out_file)``: the examples whose
		``idx`` has no stored result yet, the stored samples (one per ``idx``),
		and the results path for this run.
	"""
	examples = load_data(data_name, args.split, args.data_dir)
	# Keep only the first `num_test_sample` examples when a positive cap is given.
	if args.num_test_sample > 0:
		examples = examples[: args.num_test_sample]
	# Optional shuffle; the global RNG is re-seeded from the wall clock first.
	if args.shuffle:
		random.seed(datetime.now().timestamp())
		random.shuffle(examples)
	# Select the [start, end) window; an `end` of -1 means "through the last example".
	examples = examples[args.start : len(examples) if args.end == -1 else args.end]

	# Resolve the output location and make sure its directory exists.
	out_file_prefix, out_file = get_output_info(args, data_name)
	result_dir = f"{args.output_dir}/{data_name}"
	os.makedirs(result_dir, exist_ok=True)

	# Collect samples stored by earlier runs that share this output prefix.
	# NOTE(review): `startswith` also matches files whose prefix merely extends
	# this one (e.g. a longer max-token value) -- confirm this is acceptable.
	processed_samples = []
	if not args.overwrite:
		for file_name in os.listdir(f"{result_dir}/"):
			if file_name.endswith(".jsonl") and file_name.startswith(out_file_prefix):
				processed_samples.extend(load_jsonl(f"{result_dir}/{file_name}"))

	# Deduplicate by idx; a later entry replaces an earlier one with the same idx.
	processed_by_idx = {sample["idx"]: sample for sample in processed_samples}
	processed_samples = list(processed_by_idx.values())

	# Dict membership keeps this filter linear in the number of examples.
	examples = [example for example in examples if example["idx"] not in processed_by_idx]
	return examples, processed_samples, out_file


def main(synAdapt, raw_llm, tokenizer, device, data_name, args):
	"""Run generation and evaluation for one dataset and write its metrics file.

	Args:
		synAdapt: Model passed to ``generate_completions`` as ``model``.
		raw_llm: Model passed to ``generate_completions`` as ``raw_model``.
		tokenizer: Tokenizer used for generation and for counting output tokens.
		device: Device forwarded to ``generate_completions``.
		data_name: Name of the dataset to evaluate.
		args: Parsed command-line options.

	Returns:
		The metrics dict returned by ``evaluate``, extended with ``hard_ratio``,
		``hard_mean``, ``time_use_in_second`` and ``time_use_in_minite``.
	"""
	# Imported locally so the metrics dump below does not rely on `json`
	# being leaked into this module by one of the wildcard imports.
	import json

	examples, processed_samples, out_file = prepare_data(data_name, args)
	print("=" * 50)
	print("data:", data_name, " ,remain samples:", len(examples))
	if len(examples) > 0:
		print(examples[0])

	# Optional example fields copied verbatim onto the sample when present.
	extra_keys = (
		"level",
		"type",
		"unit",
		"solution_type",
		"choices",
		"solution",
		"ques_type",
		"ans_type",
		"answer_type",
		"dataset",
		"subfield",
		"filed",
		"theorem",
		"answer",
	)

	samples = []
	for example in tqdm(examples, total=len(examples), desc="construct examples"):
		idx = example["idx"]
		example["question"] = parse_question(example, data_name)
		# Examples without a question text are skipped entirely.
		if example["question"] == "":
			continue
		gt_cot, gt_ans = parse_ground_truth(example, data_name)
		example["gt_ans"] = gt_ans
		full_prompt = construct_prompt(example, data_name, args.prompt_type, args)
		reThink_prompt = construct_prompt(example, data_name, args.reThink_prompt_type, args)
		sample = {
			"idx": idx,
			"question": example["question"],
			"gt_cot": gt_cot,
			"gt": gt_ans,
			"prompt": full_prompt,
			"reThink_prompt": reThink_prompt
		}
		for key in extra_keys:
			if key in example:
				sample[key] = example[key]
		samples.append(sample)

	start_time = time.time()
	n_sampling = args.n_sampling
	# Repeat every prompt `n_sampling` times; repetitions of one sample are adjacent.
	prompts = [
		sample["prompt"] for sample in samples for _ in range(n_sampling)
	]
	reThink_prompts = [
		sample["reThink_prompt"] for sample in samples for _ in range(n_sampling)
	]
	#* Main Generation
	outputs_text, is_hard_list, hard_score_list = generate_completions(
			model=synAdapt,
			raw_model=raw_llm,
			tokenizer=tokenizer,
			prompts=prompts,
			reThink_prompts=reThink_prompts,
			device = device,
			content_gen_length=args.max_tokens_per_call,
			do_sample = args.do_sample,
			batch_size = args.batch_size,
			log_step= args.log_step,
			thres = args.thres,
			steps = args.time_step,
			cot_gen_length = args.cot_len
	)
	# For non-vLLM models the finish reason is set to None for now.
	outputs = [(text, len(tokenizer.tokenize(text)), None) for text in outputs_text]

	# Stop words: the generated text is cut at the first occurrence of each.
	stop_words = ["</s>", "<|im_end|>", "<|endoftext|>"]
	stop_words.extend(["assistant", "user", "_end", "_start"])
	stop_words.extend(["\nProblem", "User:", "Assistant:", "</answer>", "</s>"])

	# Strip the echoed prompt from each output, then truncate at stop words.
	codes = []
	finish_reasons = []
	output_lens = []
	for query, (output, output_len, finish_reason) in zip(prompts, outputs):
		code = output.rstrip().split(query)[-1].strip()
		for stop_word in stop_words:
			if stop_word in code:
				code = code.split(stop_word)[0].strip()
		codes.append(code)
		output_lens.append(output_len)
		finish_reasons.append(finish_reason)

	# Extract predictions from the cleaned generations.
	results = [
		run_execute(None, code, args.prompt_type, data_name) for code in codes
	]
	time_use = time.time() - start_time

	# Attach the generations and predictions back onto their samples.
	choice_letters = ["A", "B", "C", "D", "E"]
	all_samples = []
	for i, sample in enumerate(samples):
		lo, hi = i * n_sampling, (i + 1) * n_sampling
		code = codes[lo:hi]
		output_len = output_lens[lo:hi]
		result = results[lo:hi]
		preds = [item[0] for item in result]
		reports = [item[1] for item in result]
		finish_reason_list = finish_reasons[lo:hi]
		for j in range(len(preds)):
			if sample["gt"] in choice_letters and preds[j] not in choice_letters:
				preds[j] = choice_answer_clean(code[j])
			elif is_multi_choice(sample["gt"]) and not is_multi_choice(preds[j]):
				# Remove every character that is not a choice letter.
				preds[j] = "".join(
					[c for c in preds[j] if c in choice_letters]
				)
		sample.pop("prompt")
		# The difficulty lists are ordered like `prompts`, so the entry for this
		# sample sits at the index of its first repetition.
		sample.update({"code": code, "output_lens": output_len, "pred": preds, "report": reports, "finish_reason": finish_reason_list, "is_hard": is_hard_list[lo],
		"hard_score": hard_score_list[lo]})
		all_samples.append(sample)

	# Merge in the samples stored by earlier runs before scoring.
	all_samples.extend(processed_samples)
	all_samples, result_json = evaluate(
		samples=all_samples,
		data_name=data_name,
		prompt_type=args.prompt_type,
		execute=True,
	)

	# Persist outputs only when this run contributed new samples.
	if len(processed_samples) < len(all_samples) and args.save_outputs:
		save_jsonl(all_samples, out_file)

	result_json['hard_ratio'] = np.mean([item['is_hard'] for item in all_samples])
	result_json['hard_mean'] = np.mean([item['hard_score'] for item in all_samples])
	result_json["time_use_in_second"] = time_use
	result_json["time_use_in_minite"] = (
		f"{int(time_use // 60)}:{int(time_use % 60):02d}"
	)

	with open(
		out_file.replace(".jsonl", "_metrics.json"), "w"
	) as f:
		json.dump(result_json, f, indent=4)
	return result_json


def setup(args):
	"""Load the models, evaluate every requested dataset and print a summary table.

	Unless ``args.overwrite`` is set, datasets whose metrics file already exists
	are skipped, and the process exits with status 0 when nothing is left to do.

	Args:
		args: Parsed command-line options; reads ``data_names``, ``overwrite``,
			``Qwen_model_path``, ``state_dict_dir`` and ``state_dict_tag`` here and
			is forwarded to ``main``.

	Raises:
		Exception: If the checkpoint contains keys the model does not expect.
	"""
	data_list = args.data_names.split(",")
	if not args.overwrite:
		need_eval_data_list = []
		for data_name in data_list:
			_, out_file = get_output_info(args, data_name)
			out_metric_json = out_file.replace(".jsonl", "_metrics.json")
			if os.path.exists(out_metric_json):
				print(f"Skipping {data_name} because {out_metric_json} already exists.")
				continue
			need_eval_data_list.append(data_name)
		if len(need_eval_data_list) == 0:
			print("All datasets already evaluated. Exiting.")
			# `sys.exit` is always available, unlike the `exit` helper that the
			# `site` module installs for interactive sessions.
			sys.exit(0)
		data_list = need_eval_data_list

	# Load the base model and tokenizer.
	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
	qwen_config = AutoConfig.from_pretrained(args.Qwen_model_path, trust_remote_code=True,)
	qwen_config._attn_implementation = "flash_attention_2"
	qwen_model = AutoLigerKernelForCausalLM.from_pretrained(args.Qwen_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, config=qwen_config).to(device)
	qwen_tokenizer = AutoTokenizer.from_pretrained(args.Qwen_model_path, trust_remote_code=True)
	# Copy the base model before the LoRA adapter is attached to `qwen_model`,
	# so `main` also receives a model without the adapter.
	raw_qwen_model = deepcopy(qwen_model)

	lora_config = LoraConfig(
		r=8,
		lora_alpha=32,
		target_modules=["q_proj", "v_proj"],
		lora_dropout=0.1,
		bias="none",
		task_type="CAUSAL_LM"
	)
	qwen_lora_model = get_peft_model(qwen_model, lora_config, adapter_name="lora_ccot")
	synAdapt = SynAdapt(qwen_lora_model, qwen_tokenizer)
	combined_model = get_fp32_state_dict_from_zero_checkpoint(args.state_dict_dir, tag = args.state_dict_tag)
	# Non-strict load: missing keys are tolerated, unexpected keys are fatal.
	_, unexpected_keys = synAdapt.load_state_dict(combined_model, strict=False)
	if unexpected_keys:
		# NOTE(review): the message text reads "ignoring irrelevant parameters"
		# although the keys abort the run -- confirm the intended wording.
		raise Exception("⚠️ 忽略无关参数:", unexpected_keys)

	# Inference and evaluation, one dataset at a time.
	results = []
	for data_name in data_list:
		results.append(
				main(
					synAdapt, raw_qwen_model, qwen_tokenizer, 
					device, data_name, args)
				)

	# Append an "avg" column holding the mean over the evaluated datasets.
	data_list.append("avg")
	results.append(
		{
			"acc": sum([result["acc"] for result in results]) / len(results),
			"generation_len": sum([result["generation_len"] for result in results]) / len(results)
		}
	)
	# Print one header row followed by the accuracy and generation-length rows.
	pad = max([len(data_name) for data_name in data_list])
	print("\t".join(data_name.ljust(pad, " ") for data_name in data_list))
	print("\t".join([f"{result['acc']:.1f}".ljust(pad, " ") for result in results]))
	print("\t".join([f"{result['generation_len']:.1f}".ljust(pad, " ") for result in results]))


def parse_args():
	"""Parse the command-line options for the evaluation script.

	Calls ``set_seed(42)`` before parsing, forces ``top_p`` to 1 when
	``temperature`` is 0, prints the parsed namespace and returns it.

	Returns:
		The populated ``argparse.Namespace``.
	"""
	def str2bool(v):
		"""Convert a command-line string to a bool, rejecting unknown spellings."""
		if isinstance(v, bool):
			return v
		lowered = v.strip().lower()
		if lowered in ("true", "t", "yes", "y", "1"):
			return True
		if lowered in ("false", "f", "no", "n", "0"):
			return False
		# Fail loudly so a typo does not silently turn the option off.
		raise argparse.ArgumentTypeError(f"expected a boolean value, got {v!r}")

	set_seed(42)
	parser = argparse.ArgumentParser()
	parser.add_argument("--data_names", default="gsm8k,math", type=str)
	parser.add_argument("--data_dir", default="./data", type=str)
	parser.add_argument("--state_dict_dir", default="gpt-4", type=str)
	parser.add_argument("--state_dict_tag", default="gpt-4", type=str)
	parser.add_argument("--Qwen_model_path", default="gpt-4", type=str)

	parser.add_argument("--batch_size", default=1, type=int)
	parser.add_argument("--log_step", default=4, type=int)
	parser.add_argument("--time_step", type=int, default=32)
	parser.add_argument("--cot_len", type=int, default=512)
	parser.add_argument("--thres", type=float, default=1.0)


	parser.add_argument("--output_dir", default="./output", type=str)
	parser.add_argument("--prompt_type", default="qwen_deepseek_distill", type=str)
	parser.add_argument("--reThink_prompt_type", default="qwen_deepseek_distill_reThink", type=str)
	parser.add_argument("--split", default="test", type=str)
	parser.add_argument("--num_test_sample", default=-1, type=int)  # -1 for full data
	parser.add_argument("--seed", default=0, type=int)
	parser.add_argument("--start", default=0, type=int)
	parser.add_argument("--end", default=-1, type=int)
	parser.add_argument("--temperature", default=0, type=float)
	parser.add_argument("--n_sampling", default=1, type=int)
	parser.add_argument("--top_p", default=1, type=float)
	parser.add_argument("--max_tokens_per_call", default=2048, type=int)
	parser.add_argument("--max_model_len", default=4096, type=int)
	parser.add_argument("--shuffle", action="store_true")
	parser.add_argument("--save_outputs", action="store_true")
	parser.add_argument("--overwrite", action="store_true")
	parser.add_argument("--use_safetensors", action="store_true")
	parser.add_argument("--do_sample", default="False", type=str2bool)
	parser.add_argument("--num_shots", type=int, default=0)
	parser.add_argument(
		"--apply_chat_template",
		action="store_true",
		help="Apply chat template to prompt.",
	)
	parser.add_argument("--pipeline_parallel_size", type=int, default=1)
	parser.add_argument(
		"--adapt_few_shot",
		action="store_true",
		help="Few shot for multiple-choice questions, zero shot for others.",
	)
	args = parser.parse_args()
	args.top_p = (
		1 if args.temperature == 0 else args.top_p
	)  # top_p must be 1 when using greedy sampling (vllm)
	print(args)
	return args



if __name__ == "__main__":
	# Script entry point: parse the options, seed with the requested seed,
	# then run the full evaluation.
	cli_args = parse_args()
	set_seed(cli_args.seed)
	setup(cli_args)
