import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import torch.nn.functional as F



def ensure_add_token(tokenizer, llm, token="<T>"):
	"""Add ``token`` as an additional special token if the tokenizer does not have it.

	When the token is added, ``llm.resize_token_embeddings(len(tokenizer))`` is
	called so the model's embedding table covers the enlarged vocabulary.

	Args:
		tokenizer: tokenizer exposing ``get_vocab()``, ``add_special_tokens()`` and ``len()``.
		llm: model exposing ``resize_token_embeddings()``.
		token: the token string that must be present (default ``"<T>"``).
	"""
	# Check the token we were asked about (previously this was hard-coded to "<T>",
	# so any other token was never added once "<T>" existed).
	if token in tokenizer.get_vocab():
		print(f"{token} already in tokenizer")
		return
	tokenizer.add_special_tokens({"additional_special_tokens": [token]})
	llm.resize_token_embeddings(len(tokenizer))


class SynAdapt(nn.Module):
	"""
	The main LLM model of SynAdapt.

	Wraps a LoRA-wrapped Qwen causal LM (presumably a PEFT ``PeftModel`` -- it must
	expose ``set_adapter``, ``disable_adapter`` and ``base_model.model.model``).
	The adapter named ``"lora_ccot"`` iteratively refines a block of filler-token
	embeddings into a compressed chain of thought (CCoT). Answers are produced with
	adapters disabled. A small MLP head (``diffculty_classifier``) maps the hidden
	state at the last end-think position to a single difficulty score.
	"""

	def __init__(self, qwen_lora_model, qwen_tokenizer, decay_factor=1.0):
		"""
		Args:
			qwen_lora_model: LoRA-wrapped causal LM (see class docstring).
			qwen_tokenizer: tokenizer paired with ``qwen_lora_model``;
				``ensure_add_token`` is called on it for ``<T>`` and ``</think>``.
			decay_factor: base of the per-step weight
				``decay_factor ** (time_step - step_idx - 1)`` applied to the
				alignment loss in ``forward`` (default 1.0, i.e. unweighted).
		"""
		super().__init__()
		ensure_add_token(qwen_tokenizer, qwen_lora_model, "<T>")
		ensure_add_token(qwen_tokenizer, qwen_lora_model, "</think>")
		self.qwen_lora_model = qwen_lora_model
		self.qwen_tokenizer = qwen_tokenizer
		self.gradient_checkpointing_enable = self.qwen_lora_model.gradient_checkpointing_enable
		# Read by ``forward``; previously never initialised (AttributeError).
		self.decay_factor = decay_factor

		# Difficulty classifier: hidden_size -> hidden_size // 2 -> 1 score.
		# The attribute keeps its original spelling so state_dict keys stay compatible.
		self.hidden_size = qwen_lora_model.config.hidden_size
		self.diffculty_classifier = nn.Sequential(
			nn.Linear(self.hidden_size, self.hidden_size // 2),
			nn.ReLU(),
			nn.Dropout(0.1),
			nn.Linear(self.hidden_size // 2, 1)  # Output single score for difficulty
		)

	# ------------------------------------------------------------------ helpers

	def _backbone(self):
		"""Return the inner decoder stack ``qwen_lora_model.base_model.model.model``."""
		return self.qwen_lora_model.base_model.model.model

	def _decoder_layers(self):
		"""Return the backbone's decoder layers as a list."""
		return list(self._backbone().layers)

	@staticmethod
	def _register_feature_hooks(layers, store, position):
		"""Hook each layer so its hidden state at token index ``position`` lands in ``store``.

		``store`` is keyed by the layer module. Returns the hook handles; the caller
		must remove them.
		"""
		def hook_fn(module, inputs, output):
			# Depending on the transformers version, a decoder layer may return a tuple
			# whose first item is the hidden state, or the hidden-state tensor itself.
			hidden = output[0] if isinstance(output, tuple) else output
			store[module] = hidden[:, position, :]
		return [layer.register_forward_hook(hook_fn) for layer in layers]

	def _refine_ccot(self, question_embeds, question_masks, cot_fill_embeds, cot_fill_masks, time_step):
		"""Refine the CCoT slot embeddings ``time_step`` times with the ``lora_ccot`` adapter.

		On each step the backbone is run on ``[question, current CCoT]``. The last
		hidden states at the CCoT positions become the next CCoT embeddings; the
		question embeddings are never changed. Returns the final CCoT slice
		(same number of positions as ``cot_fill_embeds``).
		"""
		self.qwen_lora_model.set_adapter("lora_ccot")
		backbone = self._backbone()
		fill_len = cot_fill_embeds.shape[1]
		attention_mask = torch.cat([question_masks, cot_fill_masks], dim=1)
		ccot_embeds = cot_fill_embeds
		for _ in range(time_step):
			outputs = backbone.forward(
				inputs_embeds=torch.cat([question_embeds, ccot_embeds], dim=1),
				attention_mask=attention_mask,
			)
			ccot_embeds = outputs.last_hidden_state[:, -fill_len:]
		return ccot_embeds

	def _difficulty_logit(self, prompt_ids, prompt_masks, cot_fill_ids, cot_fill_masks,
						end_think_ids, end_think_masks, time_step):
		"""Return the raw classifier score for one prompt.

		The CCoT for the prompt is generated with ``generate_ccot``. The sequence
		``[prompt, CCoT, end_think]`` is then encoded by the adapter-free backbone,
		and the last position's hidden state is scored.
		"""
		ccot_embeds = self.generate_ccot(prompt_ids, prompt_masks, cot_fill_ids, cot_fill_masks,
										end_think_ids, end_think_masks, time_step=time_step)
		embed = self.qwen_lora_model.get_input_embeddings()
		inputs_embeds = torch.cat([embed(prompt_ids), ccot_embeds, embed(end_think_ids)], dim=1)
		attention_mask = torch.cat([prompt_masks, cot_fill_masks, end_think_masks], dim=1)
		with self.qwen_lora_model.disable_adapter():
			outputs = self._backbone().forward(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
		return self.diffculty_classifier(outputs.last_hidden_state[:, -1])

	# ------------------------------------------------------------------ losses

	def compute_l1_loss(self, features_generate, features_target):
		"""Return a list of smooth-L1 losses, one per key of ``features_generate``.

		Each prediction ``features_generate[k]`` is compared with ``features_target[k]``.
		"""
		l1_losses = []
		for layer_name in features_generate.keys():
			target = features_target[layer_name]
			pred = features_generate[layer_name]
			loss = F.smooth_l1_loss(pred, target)
			l1_losses.append(loss)
		return l1_losses

	@staticmethod
	def compute_step_loss(step_ccot, target_ccot):
		"""Mean absolute error between ``step_ccot`` and ``target_ccot``.

		This is a static method because the original definition lacked ``self``.
		It is therefore callable both on the class and on an instance.
		"""
		loss_fn = nn.L1Loss()
		return loss_fn(step_ccot, target_ccot)

	def forward(self, question_ids, question_masks,
					answer_ids, answer_masks,
					target_dcot_ids, target_dcot_mask, target_ccot_list,
					cot_fill_ids, cot_fill_masks,
					end_think_ids, end_think_masks,
					time_step=4
					):
		"""
		Compute the three SynAdapt training losses.

		Args:
			question_ids / question_masks: question tokens and mask.
			answer_ids / answer_masks: answer tokens and mask; the answer is supervised by CE.
			target_dcot_ids / target_dcot_mask: explicit chain-of-thought tokens, used
				(adapter-free, no grad) to produce the per-layer target features.
			target_ccot_list: alignment targets; ``target_ccot_list[0][time_step - 1]``
				is compared with the refined CCoT embeddings.
			cot_fill_ids / cot_fill_masks: filler tokens whose embeddings seed the CCoT.
			end_think_ids / end_think_masks: tokens closing the thinking section.
			time_step: number of CCoT refinement iterations (must be >= 1).

		Returns:
			(align_loss, ans_ce_loss, final_l1_loss):
			- align_loss: smooth-L1 between the final refined CCoT and its target.
			- ans_ce_loss: LM loss on the answer tokens given
			  ``[question, CCoT, end_think]``, with adapters disabled.
			- final_l1_loss: mean over decoder layers of smooth-L1 between the hidden
			  state at the last end-think token for the CCoT input and the
			  adapter-free hidden state at the last token of
			  ``[question, target_dcot, end_think]``.

		Raises:
			ValueError: if ``time_step < 1`` (there would be no step to supervise).
		"""
		if time_step < 1:
			raise ValueError(f"time_step must be >= 1, got {time_step}")

		embed = self.qwen_lora_model.get_input_embeddings()
		question_embeds = embed(question_ids)
		end_think_embeds = embed(end_think_ids)
		cot_fill_embeds = embed(cot_fill_ids)
		answer_embeds = embed(answer_ids)
		layers = self._decoder_layers()

		# 1) Target features: the adapter-free model reads the explicit CoT; keep every
		#    decoder layer's hidden state at the last token (the last end-think token).
		last_token_features_dcot = {}
		with torch.no_grad():
			handles = self._register_feature_hooks(layers, last_token_features_dcot, -1)
			try:
				with self.qwen_lora_model.disable_adapter():
					self.qwen_lora_model.forward(
						input_ids=torch.cat([question_ids, target_dcot_ids, end_think_ids], dim=1),
						attention_mask=torch.cat([question_masks, target_dcot_mask, end_think_masks], dim=1),
					)
			finally:
				for handle in handles:
					handle.remove()

		# 2) Alignment loss: only the final refinement step is supervised. Its weight is
		#    decay_factor ** (time_step - step_idx - 1), which is decay_factor ** 0 here.
		ccot_embeds = self._refine_ccot(question_embeds, question_masks, cot_fill_embeds, cot_fill_masks, time_step)
		last_step = time_step - 1
		align_loss = F.smooth_l1_loss(ccot_embeds, target_ccot_list[0][last_step]) * (
			self.decay_factor ** (time_step - last_step - 1)
		)
		torch.cuda.empty_cache()

		# 3) Answer CE loss, with adapters disabled. Only answer tokens are labelled; the
		#    prefix is set to -100. Hooks capture each layer at the last end-think token,
		#    which sits just before the answer.
		answer_len = answer_ids.shape[1]
		prefix_len = question_masks.shape[1] + cot_fill_masks.shape[1] + end_think_masks.shape[1]
		labels = torch.cat(
			(
				torch.full((question_masks.shape[0], prefix_len), -100,
						dtype=answer_ids.dtype, device=answer_ids.device),
				answer_ids,
			),
			dim=1,
		)
		last_token_features_ccot = {}
		handles = self._register_feature_hooks(layers, last_token_features_ccot, -answer_len - 1)
		try:
			with self.qwen_lora_model.disable_adapter():
				qwen_outputs = self.qwen_lora_model.forward(
					inputs_embeds=torch.cat([question_embeds, ccot_embeds, end_think_embeds, answer_embeds], dim=1),
					attention_mask=torch.cat([question_masks, cot_fill_masks, end_think_masks, answer_masks], dim=1),
					labels=labels,
					ignore_index=-100,
				)
		finally:
			for handle in handles:
				handle.remove()
		ans_ce_loss = qwen_outputs.loss
		torch.cuda.empty_cache()

		# 4) Per-layer feature alignment at the end-think position, averaged over layers.
		final_l1_loss = torch.mean(torch.stack(
			self.compute_l1_loss(last_token_features_ccot, last_token_features_dcot)
		))
		torch.cuda.empty_cache()

		return align_loss, ans_ce_loss, final_l1_loss

	# ------------------------------------------------------------------ inference

	@torch.no_grad()
	def generate(self, question_ids, question_masks,
					cot_fill_ids, cot_fill_masks,
					end_think_ids, end_think_masks,
					time_step=4,
					do_sample=False, max_new_tokens=2048,
					**gen_args):
		"""Generate answer token ids conditioned on the question and its refined CCoT.

		In greedy mode (``do_sample=False``) the LM is put in eval mode for the call.
		Its previous train/eval mode is restored afterwards, even on error.
		Extra ``gen_args`` are forwarded to ``qwen_lora_model.generate``.
		Returns ``outputs.sequences``.
		"""
		was_training = self.qwen_lora_model.training
		if not do_sample:
			self.qwen_lora_model.eval()
		try:
			embed = self.qwen_lora_model.get_input_embeddings()
			question_embeds = embed(question_ids)
			end_think_embeds = embed(end_think_ids)
			ccot_embeds = self._refine_ccot(question_embeds, question_masks, embed(cot_fill_ids), cot_fill_masks, time_step)

			# Directly generate the answer from [question, CCoT, end_think] without adapters.
			with self.qwen_lora_model.disable_adapter():
				outputs = self.qwen_lora_model.generate(
					inputs_embeds=torch.cat([question_embeds, ccot_embeds, end_think_embeds], dim=1),
					attention_mask=torch.cat([question_masks, cot_fill_masks, end_think_masks], dim=1),
					do_sample=do_sample, max_new_tokens=max_new_tokens,
					return_dict_in_generate=True, output_logits=True, **gen_args,
				)
		finally:
			if not do_sample:
				self.qwen_lora_model.train(was_training)
		return outputs.sequences

	@torch.no_grad()
	def generate_ccot(self, question_ids, question_masks,
					cot_fill_ids, cot_fill_masks,
					end_think_ids, end_think_masks,
					time_step=4,
					do_sample=False, max_new_tokens=2048,
					**gen_args):
		"""Return the refined CCoT embeddings (one per filler position) for the question.

		``end_think_*``, ``max_new_tokens`` and ``gen_args`` are accepted for signature
		parity with ``generate`` but are unused. In greedy mode the LM's previous
		train/eval mode is restored afterwards; previously it was left in eval mode.
		"""
		was_training = self.qwen_lora_model.training
		if not do_sample:
			self.qwen_lora_model.eval()
		try:
			embed = self.qwen_lora_model.get_input_embeddings()
			return self._refine_ccot(embed(question_ids), question_masks, embed(cot_fill_ids), cot_fill_masks, time_step)
		finally:
			if not do_sample:
				self.qwen_lora_model.train(was_training)

	# ------------------------------------------------------------------ difficulty

	def forward_diffclf(self, win_prompt_ids, win_prompt_masks,
						lose_prompt_ids, lose_prompt_masks,
						cot_fill_ids, cot_fill_masks,
						end_think_ids, end_think_masks,
						time_step=4):
		"""Pairwise ranking loss ``-log sigmoid(score_win - score_lose)``.

		The loss is per sample and not reduced. It is computed with ``F.logsigmoid``,
		which avoids ``log(0)`` when the score gap is large and negative.
		"""
		win_difficulty_score = self._difficulty_logit(
			win_prompt_ids, win_prompt_masks, cot_fill_ids, cot_fill_masks,
			end_think_ids, end_think_masks, time_step)
		lose_difficulty_score = self._difficulty_logit(
			lose_prompt_ids, lose_prompt_masks, cot_fill_ids, cot_fill_masks,
			end_think_ids, end_think_masks, time_step)
		return -F.logsigmoid(win_difficulty_score - lose_difficulty_score)

	def judge_score(self, win_prompt_ids, win_prompt_masks,
						cot_fill_ids, cot_fill_masks,
						end_think_ids, end_think_masks,
						time_step=4):
		"""Return ``sigmoid(difficulty score)`` for a prompt, without gradients.

		The classifier runs in eval mode for the call; its previous mode is restored
		afterwards. It uses the same backbone path and attention mask as
		``forward_diffclf``. Previously it called ``qwen_lora_model.model``, whose
		output has no ``last_hidden_state``.
		"""
		was_training = self.diffculty_classifier.training
		self.diffculty_classifier.eval()
		try:
			with torch.no_grad():
				win_difficulty_score = torch.sigmoid(self._difficulty_logit(
					win_prompt_ids, win_prompt_masks, cot_fill_ids, cot_fill_masks,
					end_think_ids, end_think_masks, time_step))
		finally:
			self.diffculty_classifier.train(was_training)
		return win_difficulty_score