import os
import json
import torch
import argparse
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def _pad_token_id(model, tokenizer):
	"""
	Return the token id used to right-pad batched ``input_ids``.

	Prefers ``model.config.pad_token_id``. Hugging Face sequence-classification
	heads typically locate each row's last real token by searching
	``input_ids`` for that id, so padding with anything else (e.g. 0, a normal
	vocabulary token) makes padded positions look like content.
	NOTE(review): the QRM remote code is assumed to follow the same
	convention; if it pools via ``attention_mask`` instead, the pad value is
	irrelevant and this is harmless.

	Falls back to ``tokenizer.pad_token_id``, then to 0 (the previous
	hard-coded padding value) when neither is defined.
	"""
	config = getattr(model, "config", None)
	for candidate in (getattr(config, "pad_token_id", None), getattr(tokenizer, "pad_token_id", None)):
		if candidate is not None:
			return int(candidate)
	return 0


def batched_rewards(conversations, model, tokenizer, batch_size=16, progress_desc="Scoring"):
	"""
	Compute rewards for a list of conversations using a QRM model.

	Each conversation is rendered with ``tokenizer.apply_chat_template`` and
	scored in mini-batches of ``batch_size``. Within a batch, shorter
	sequences are right-padded with the model's pad token id and masked out
	through ``attention_mask``.

	Args:
		conversations: List of conversation lists (each conversation is a list of message dicts)
		model: The QRM model; its forward output must expose a ``.score`` tensor
		tokenizer: The tokenizer whose chat template renders each conversation
		batch_size: Number of conversations scored per forward pass
		progress_desc: Description for a progress bar. NOTE: currently unused;
			kept so existing callers that pass it keep working.

	Returns:
		List of float reward scores, one per conversation, in input order.
	"""
	all_scores = []
	device = next(model.parameters()).device
	pad_id = _pad_token_id(model, tokenizer)

	for start in range(0, len(conversations), batch_size):
		batch_conversations = conversations[start:start + batch_size]

		# Render each conversation to a 1-D tensor of token ids.
		batch_input_ids = [
			tokenizer.apply_chat_template(conv, return_tensors="pt").squeeze(0)
			for conv in batch_conversations
		]

		# Right-pad to the longest sequence in this batch.
		max_len = max(ids.size(0) for ids in batch_input_ids)
		padded_input_ids = torch.full((len(batch_input_ids), max_len), pad_id, dtype=torch.long)
		attention_mask = torch.zeros(len(batch_input_ids), max_len, dtype=torch.long)
		for row, ids in enumerate(batch_input_ids):
			padded_input_ids[row, :ids.size(0)] = ids
			attention_mask[row, :ids.size(0)] = 1

		padded_input_ids = padded_input_ids.to(device)
		attention_mask = attention_mask.to(device)

		with torch.no_grad():
			output = model(input_ids=padded_input_ids, attention_mask=attention_mask)
			# `.score` is the expectation of the QRM reward distribution;
			# squeeze(-1) turns (batch_size, 1) into (batch_size,).
			scores = output.score.squeeze(-1).cpu().float().tolist()

		# If the score tensor was already 1-D and the batch has one row, the
		# squeeze yields a 0-d tensor and .tolist() returns a bare float.
		if isinstance(scores, float):
			scores = [scores]
		all_scores.extend(scores)

	return all_scores


def _write_json_atomic(path, obj):
	"""Write ``obj`` as indented JSON to ``path`` via a sibling temp file and ``os.replace``."""
	tmp_path = f"{path}.tmp"
	with open(tmp_path, "w", encoding="utf-8") as f:
		json.dump(obj, f, indent=4)
	# os.replace swaps the file in one step, so an interruption during the
	# dump leaves the previous checkpoint intact instead of a truncated file.
	os.replace(tmp_path, path)


def main(dataset_name, reward_model, model_code, bias_type, model_name=None, *, num_responses=16, batch_size=16):
	"""
	Score Best-of-N sampled responses with a reward model and save the rewards.

	Reads the RL test prompts and the BoN sampled responses for
	``dataset_name`` / ``bias_type`` under ``$TAMPERING_HOME``. Each
	(prompt, response_i) pair is scored with ``reward_model``, and the result
	is written as a JSON list containing one list of ``num_responses`` rewards
	per prompt. The output file is rewritten after every prompt, so partial
	progress is kept if the run stops early.

	Args:
		dataset_name: Dataset directory name under ``$TAMPERING_HOME/datasets``.
		reward_model: Model id or path passed to ``from_pretrained``.
		model_code: Tag embedded in the output file name.
		bias_type: Tag selecting the sampled-responses file and naming the output.
		model_name: Currently unused; accepted for CLI compatibility.
		num_responses: Responses per prompt, read from keys ``response_1`` .. ``response_<num_responses>``.
		batch_size: Conversations per reward-model forward pass.

	Raises:
		RuntimeError: If the ``TAMPERING_HOME`` environment variable is unset or empty.
	"""
	tampering_home = os.getenv("TAMPERING_HOME")
	if not tampering_home:
		# Without this check the paths silently become "None/datasets/...".
		raise RuntimeError("TAMPERING_HOME environment variable is not set")

	rl_dir = f"{tampering_home}/datasets/{dataset_name}/rl"
	prompts_path = f"{rl_dir}/{dataset_name}_RL_500_test.json"
	responses_path = f"{rl_dir}/bon/{dataset_name}_BoN_500_sampled_{bias_type}.json"
	target_path = f"{rl_dir}/bon/{dataset_name}_BoN_500_reward_{model_code}_{bias_type}.json"

	# Read inputs before loading the model so a bad path fails fast.
	with open(prompts_path, "r", encoding="utf-8") as f:
		prompts = json.load(f)
	with open(responses_path, "r", encoding="utf-8") as f:
		responses = json.load(f)

	# NOTE(review): trust_remote_code=True executes Python shipped with the
	# model repo; only point --reward_model at trusted sources.
	model = AutoModelForSequenceClassification.from_pretrained(
		reward_model,
		torch_dtype=torch.bfloat16,
		attn_implementation="sdpa",
		device_map="auto",
		trust_remote_code=True,
	).eval()
	tokenizer = AutoTokenizer.from_pretrained(reward_model, use_fast=True)

	all_rewards = []
	# zip stops at the shorter input, so the progress total does too.
	n_pairs = min(len(prompts), len(responses))
	for prompt, item in tqdm(zip(prompts, responses), total=n_pairs, desc="Prompts"):
		conversation_list = [
			prompt["messages"] + [{"role": "assistant", "content": item[f"response_{i}"]}]
			for i in range(1, num_responses + 1)
		]

		rewards = batched_rewards(conversation_list, model, tokenizer, batch_size=batch_size, progress_desc="Scoring")
		all_rewards.append(rewards)

		# Checkpoint after every prompt.
		_write_json_atomic(target_path, all_rewards)


if __name__ == "__main__":
	# CLI entry point: parse arguments and forward them positionally to main().
	parser = argparse.ArgumentParser(description="Label rewards for BoN responses")
	parser.add_argument("--dataset_name", type=str, required=True, help="Dataset name (e.g., hhrlhf)")
	# --reward_model is handed to from_pretrained, so it is a model id/path, not a short code.
	parser.add_argument("--reward_model", type=str, required=True, help="Hugging Face model id or local path of the reward model")
	parser.add_argument("--model_code", type=str, required=True, help="Model code (e.g., rm)")
	parser.add_argument("--bias_type", type=str, required=True, help="Bias type (e.g., culture, tesla)")
	parser.add_argument("--model_name", type=str, default=None, help="Optional model name (currently not used by main)")
	args = parser.parse_args()

	main(args.dataset_name, args.reward_model, args.model_code, args.bias_type, args.model_name)