import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn

from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import os
import random
from torch.utils.data import Dataset
import json
from tqdm import tqdm
import pickle

# Prompt templates keyed by model family.
# * "qwen" is a single format string (implicit string concatenation, not a tuple).
# * The DeepSeek-distill entries are 3-tuples used positionally:
#   [0] prompt head, formatted with ``input``; [1] reasoning span with a ``COT``
#   field; [2] answer tail, formatted with ``output``.
# Doubled braces (``{{}}``) come out of ``str.format`` as a literal ``{}`` for \boxed.
_DEEPSEEK_DISTILL_TAIL = (
	"{COT}</think>",
	"{output}<｜end▁of▁sentence｜>",
)

prompt_mapping = {
	"qwen": "".join([
		"<|im_start|>system\nPlease reason step by step, and put your final answer within \\boxed{{}}.<|im_end|>\n",
		"<|im_start|>user\n{input}<|im_end|>\n",
		"<|im_start|>assistant\n",
	]),
	"qwen_deepseek_distill": (
		"<｜begin▁of▁sentence｜>Please reason step by step, and put your final answer within \\boxed{{}}"
		"<｜User｜>{input}"
		"<｜Assistant｜><think>",
	) + _DEEPSEEK_DISTILL_TAIL,
	"qwen_deepseek_distill_reThink": (
		"<｜begin▁of▁sentence｜>Think step by step, but only keep minimum draft for each thinking step, with 5 words at most.\nReturn the answer at the end of the response within \\boxed{{}}"
		"<｜User｜>{input}"
		"<｜Assistant｜><think>",
	) + _DEEPSEEK_DISTILL_TAIL,
}





def post_process_prefix(prefix_list, target_num):
	"""Pick ``target_num`` evenly spaced elements from ``prefix_list``.

	The picked positions come from ``np.linspace(0, len - 1, target_num)``
	truncated to integers, so the first and last elements are always kept
	when ``target_num >= 2``.  If ``target_num`` is not smaller than the
	length, ``prefix_list`` itself is returned unchanged (same object).
	"""
	count = len(prefix_list)
	if count <= target_num:
		# Nothing to thin out.
		return prefix_list
	picks = np.linspace(0, count - 1, num=target_num, dtype=int).tolist()
	return [prefix_list[pos] for pos in picks]


class CotDataset(Dataset):
	"""Dataset pairing raw problems with precomputed compressed-CoT traces.

	Two JSONL files are read at construction time:

	* ``raw_data_path``: one record per problem, keyed by ``Raw_data_idx``.
	  ``__getitem__`` reads ``Question``, ``Answer_Content``, ``GT_Answer``,
	  ``Split_COT_Content``, ``Difficulty``, ``Topic`` and ``Total_Length``.
	* ``index_data_path``: one record per example, holding ``raw_data_idx``
	  (a key into the raw records) and ``trace_path`` (a pickle file whose
	  elements expose ``.cpu()``, i.e. presumably tensors).

	Args:
		tokenizer: stored on the instance; not used by this class directly.
		index_data_path: JSONL index file described above.
		raw_data_path: JSONL raw-problem file described above.
		target_ccot_dir: stored only; ``trace_path`` values are opened as given.
		cot_length: number of ``"<T>"`` repetitions in ``cot_fill_str``.
		time_step: maximum number of trace elements kept per example; longer
			traces are evenly subsampled via ``post_process_prefix``.
		test_gpu: accepted for interface compatibility; unused.
		shuffle: shuffle the index records once, at construction.
	"""

	def __init__(self, tokenizer, index_data_path, raw_data_path, target_ccot_dir,
					cot_length=512, time_step=4,
						test_gpu=False, shuffle= True):
		self.index_data_path = index_data_path
		self.raw_data_path = raw_data_path
		self.target_ccot_dir = target_ccot_dir
		self.tokenizer = tokenizer
		self.cot_length = cot_length
		self.time_step = time_step

		# Raw problems, indexed by id.  A later line with a repeated
		# Raw_data_idx overwrites the earlier one; the printed count is the
		# number of records read, not the number of distinct ids.
		raw_records = self._read_jsonl(raw_data_path)
		self.rawId2data = {record['Raw_data_idx']: record for record in raw_records}
		print(f"[RAW Data Length]: {len(raw_records)}")
		del raw_records

		self.index_data_list = self._read_jsonl(index_data_path)
		print(f"[Index Data Length]: {len(self.index_data_list)}")
		self.shuffle = shuffle
		if self.shuffle:
			random.shuffle(self.index_data_list)

	@staticmethod
	def _read_jsonl(path):
		"""Return the JSON objects stored one per line in ``path``.

		The file is decoded as UTF-8 explicitly, so the result does not depend
		on the platform locale.  Blank lines are skipped instead of making
		``json.loads`` raise.
		"""
		with open(path, 'r', encoding='utf-8') as f:
			return [json.loads(line) for line in f if line.strip()]

	def __len__(self):
		return len(self.index_data_list)

	def __getitem__(self, idx):
		index_record = self.index_data_list[idx]
		raw_data = self.rawId2data[index_record['raw_data_idx']]

		# Element [0] of each distill template is the prompt head; [2] is the answer tail.
		prompt = prompt_mapping['qwen_deepseek_distill'][0].format(input = raw_data['Question'])
		prompt_reThink = prompt_mapping['qwen_deepseek_distill_reThink'][0].format(input = raw_data['Question'])
		answer_content = prompt_mapping['qwen_deepseek_distill'][2].format(output = raw_data['Answer_Content'])
		gt_answer = raw_data['GT_Answer']

		# NOTE(review): pickle.load can execute arbitrary code stored in the
		# file.  Only point trace_path at trusted, locally generated traces.
		with open(index_record['trace_path'], 'rb') as file:
			target_ccot = pickle.load(file)
		if len(target_ccot) > self.time_step:
			target_ccot = post_process_prefix(target_ccot, self.time_step)
		target_ccot = [item.cpu() for item in target_ccot]
		target_dcot = "".join(raw_data['Split_COT_Content'])

		# One "<T>" string per placeholder.  How many tokens this becomes depends
		# on how the tokenizer splits "<T>" (TODO confirm it is a single added token).
		cot_fill_str = "<T>" * self.cot_length
		end_think_mark = "</think>"
		# Scaled by 1/10; presumably Difficulty is on a 0-10 scale (TODO confirm).
		difficulty = raw_data['Difficulty'] / 10.0

		return {
			'prompt': prompt,
			'prompt_reThink': prompt_reThink,
			'answer_content': answer_content,
			'gt_answer': gt_answer,
			'difficulty': difficulty,
			'topic': raw_data['Topic'],

			"target_ccot": target_ccot,
			"target_dcot": target_dcot,

			"cot_fill_str": cot_fill_str,
			"end_think_mark": end_think_mark,
			"total_length": raw_data['Total_Length']
		}


def collate_fn(batch, tokenizer):
	"""Tokenize a batch of ``CotDataset`` samples into padded tensors.

	Every text field is encoded with ``tokenizer.batch_encode_plus(...,
	padding=True, return_tensors='pt')``.

	* Prompts are left-padded and keep the tokenizer's default special tokens.
	* The continuation segments are right-padded and encoded with
	  ``add_special_tokens=False``. These are the ``<T>`` fill, the target
	  discrete CoT, the ``</think>`` marker and the answer.
	* ``target_ccot`` is not tokenized. It is passed through as a plain list
	  with one entry per sample.

	Args:
		batch: list of dicts as returned by ``CotDataset.__getitem__``.
		tokenizer: HuggingFace-style tokenizer. Its ``batch_encode_plus`` is
			assumed to accept ``padding_side`` (confirm against the installed
			transformers version).

	Returns:
		A pair ``(tensors, meta)``.
		``tensors`` is the 14-tuple ``(question_ids, question_masks,
		question_ids_reThink, question_masks_reThink, target_dcot_ids,
		target_dcot_mask, target_ccot_list, cot_fill_ids, cot_fill_mask,
		end_think_ids, end_think_mask, answer_ids, answer_mask,
		difficulty_labels)``.
		``meta`` is ``(question_list, answer_list, avg_len, max_len)``, where
		``avg_len``/``max_len`` are the mean and maximum ``total_length``.
	"""
	def _encode(texts, side, **kwargs):
		# Every field uses the same call; only padding side and special-token handling vary.
		encoded = tokenizer.batch_encode_plus(texts, padding=True, return_tensors='pt', padding_side=side, **kwargs)
		return encoded['input_ids'], encoded['attention_mask']

	def _field(key):
		return [sample[key] for sample in batch]

	# Prompts: left padding keeps the real tokens right-aligned at the end of each row.
	question_list = _field('prompt')
	question_ids, question_masks = _encode(question_list, "left")
	question_ids_reThink, question_masks_reThink = _encode(_field('prompt_reThink'), "left")

	# Continuation segments: right padding, no special tokens added.
	cot_fill_ids, cot_fill_mask = _encode(_field('cot_fill_str'), "right", add_special_tokens=False)
	target_ccot_list = _field('target_ccot')
	target_dcot_ids, target_dcot_mask = _encode(_field('target_dcot'), "right", add_special_tokens=False)
	end_think_ids, end_think_mask = _encode(_field('end_think_mark'), "right", add_special_tokens=False)

	answer_list = _field('answer_content')
	answer_ids, answer_mask = _encode(answer_list, "right", add_special_tokens=False)

	lengths = _field('total_length')
	avg_len = sum(lengths) / len(batch)
	max_len = max(lengths)
	difficulty_labels = torch.tensor(_field('difficulty'))

	return (question_ids, question_masks, question_ids_reThink, question_masks_reThink,
		target_dcot_ids, target_dcot_mask, target_ccot_list, cot_fill_ids, cot_fill_mask, end_think_ids, end_think_mask,
			answer_ids, answer_mask, difficulty_labels), \
			(question_list, answer_list, avg_len, max_len)





class CotDataset_hardness(Dataset):
	"""Pairwise "win"/"lose" question dataset used for hardness comparison.

	``binary_path`` is a JSONL file. Each non-blank line is one record with
	``win_question`` and ``lose_question`` fields. ``__getitem__`` wraps both
	questions in a DeepSeek-distill prompt head. It also returns the
	``</think>`` marker and ``cot_length`` repetitions of ``"<T>"``.

	Args:
		tokenizer: stored on the instance; not used by this class directly.
		binary_path: path of the JSONL file described above.
		cot_length: number of ``"<T>"`` repetitions in ``cot_fill_str``.
		shuffle: shuffle the records once, at construction.
	"""

	# Template this dataset originally asked for. It may be absent from
	# ``prompt_mapping``; in that case the plain distill template is used.
	_PROMPT_KEY = 'qwen_deepseek_distill_split'
	_FALLBACK_PROMPT_KEY = 'qwen_deepseek_distill'

	def __init__(self, tokenizer, binary_path, cot_length=512, shuffle= True):
		self.tokenizer = tokenizer
		# Both values used to be dropped. That made ``if self.shuffle`` raise
		# AttributeError here, and ``self.cot_length`` raise it in __getitem__.
		self.cot_length = cot_length
		self.shuffle = shuffle

		# Decode as UTF-8 explicitly; skip blank lines rather than crash json.loads.
		with open(binary_path, 'r', encoding='utf-8') as file:
			self.binary_index_list = [json.loads(line) for line in file if line.strip()]

		if self.shuffle:
			random.shuffle(self.binary_index_list)

	def __len__(self):
		return len(self.binary_index_list)

	def _format_prompt(self, question):
		"""Wrap ``question`` in the prompt head (element 0) of the chosen template."""
		template = prompt_mapping.get(self._PROMPT_KEY)
		if template is None:
			template = prompt_mapping[self._FALLBACK_PROMPT_KEY]
		return template[0].format(input = question)

	def __getitem__(self, idx):
		binary_data = self.binary_index_list[idx]

		win_question = binary_data['win_question']
		win_prompt = self._format_prompt(win_question)

		# This used to read 'win_question' (a copy-paste slip), which made the
		# win and lose sides of every pair identical.
		lose_question = binary_data['lose_question']
		lose_prompt = self._format_prompt(lose_question)

		end_think_mark = "</think>"
		cot_fill_str = "<T>" * self.cot_length

		return {
			'win_question': win_question,
			'win_prompt': win_prompt,
			'lose_question': lose_question,
			"lose_prompt": lose_prompt,

			"end_think_mark": end_think_mark,
			"cot_fill_str": cot_fill_str
		}


def collate_fn_hardness(batch, tokenizer):
	"""Tokenize a batch of ``CotDataset_hardness`` samples.

	The win and lose prompts are left-padded and keep the tokenizer's default
	special tokens. The ``</think>`` marker and the ``<T>`` fill string are
	right-padded and encoded with ``add_special_tokens=False``.

	Returns:
		``(tensors, meta)``.
		``tensors`` is ``(win_prompt_ids, win_prompt_masks, lose_prompt_ids,
		lose_prompt_masks, cot_fill_ids, cot_fill_mask, end_think_ids,
		end_think_mask)``.
		``meta`` is ``(win_questions, lose_questions)``: the question strings
		taken from each sample.
	"""
	def column(name):
		return [item[name] for item in batch]

	def tokenize(texts, side, **extra):
		out = tokenizer.batch_encode_plus(texts, padding=True, return_tensors='pt', padding_side=side, **extra)
		return out['input_ids'], out['attention_mask']

	win_questions = column('win_question')
	win_ids, win_mask = tokenize(column('win_prompt'), "left")
	lose_questions = column('lose_question')
	lose_ids, lose_mask = tokenize(column('lose_prompt'), "left")

	think_ids, think_mask = tokenize(column('end_think_mark'), "right", add_special_tokens=False)
	fill_ids, fill_mask = tokenize(column('cot_fill_str'), "right", add_special_tokens=False)

	tensors = (win_ids, win_mask, lose_ids, lose_mask, fill_ids, fill_mask, think_ids, think_mask)
	return tensors, (win_questions, lose_questions)