import os
import json
import torch
import argparse
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def batched_rewards(texts, model, tokenizer, batch_size=16, progress_desc="Scoring", max_length=5120):
	"""Score ``texts`` with a sequence-classification reward model, in batches.

	Args:
		texts: Sliceable sequence (e.g. a list) of already-formatted input
			strings, such as chat-templated conversations.
		model: Model whose forward pass accepts ``input_ids`` and
			``attention_mask`` keyword arguments and returns an object with a
			``.logits`` tensor. Inputs are moved to ``model.device``.
		tokenizer: Callable tokenizer returning a mapping that contains
			``input_ids`` and ``attention_mask`` tensors.
		batch_size: Number of texts per forward pass; must be positive.
		progress_desc: Currently unused; kept so existing call sites still work.
		max_length: Token limit passed to the tokenizer; longer inputs are
			truncated. Defaults to 5120, the previously hard-coded value.

	Returns:
		A flat list of floats, one per input text, in input order. When the
		model returns 2-D logits, only column 0 is used.

	Raises:
		ValueError: If ``batch_size`` is not positive (a negative value would
			otherwise silently score nothing).
	"""
	if batch_size <= 0:
		raise ValueError(f"batch_size must be positive, got {batch_size}")

	all_scores = []
	for start in range(0, len(texts), batch_size):
		batch_texts = texts[start:start + batch_size]
		enc = tokenizer(
			batch_texts,
			return_tensors="pt",
			padding=True,
			truncation=True,
			max_length=max_length,
		)
		# Only these two tensors are forwarded to the model; any other
		# tokenizer outputs are not passed, so they are not moved to the device.
		input_ids = enc["input_ids"].to(model.device)
		attention_mask = enc["attention_mask"].to(model.device)

		with torch.no_grad():
			logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

		# Handle (batch, k) by keeping column 0, then flatten so that (batch,)
		# and 0-d outputs also produce a flat list of Python floats.
		if logits.dim() == 2:
			logits = logits[:, 0]
		all_scores.extend(logits.float().cpu().reshape(-1).tolist())

	return all_scores


def main(dataset_name, reward_model, model_code, bias_type, model_name=None, num_responses=16):
	"""Score every sampled Best-of-N response with a reward model and save the rewards.

	Reads prompts from
	``$TAMPERING_HOME/datasets/<dataset>/rl/<dataset>_RL_500_test.json`` and
	sampled responses from
	``$TAMPERING_HOME/datasets/<dataset>/rl/bon/<dataset>_BoN_500_sampled_<bias_type>.json``.
	It writes a JSON list (one inner list of rewards per prompt) to
	``$TAMPERING_HOME/datasets/<dataset>/rl/bon/<dataset>_BoN_500_reward_<model_code>_<bias_type>.json``.
	The output file is rewritten after every prompt as a checkpoint.

	Args:
		dataset_name: Dataset directory name and file prefix.
		reward_model: Model id or path passed to ``from_pretrained``.
		model_code: Tag used only in the output file name.
		bias_type: Tag selecting the sampled-responses file; also used in the
			output file name.
		model_name: Unused; kept for caller/CLI compatibility.
		num_responses: Responses per prompt, read from keys ``response_1``
			through ``response_<num_responses>`` (default 16).

	Raises:
		RuntimeError: If the ``TAMPERING_HOME`` environment variable is unset
			or empty.
	"""
	import warnings

	tampering_home = os.getenv("TAMPERING_HOME")
	if not tampering_home:
		# Without this check the paths silently become "None/datasets/...".
		raise RuntimeError("TAMPERING_HOME environment variable is not set")

	rl_dir = f"{tampering_home}/datasets/{dataset_name}/rl"
	prompts_path = f"{rl_dir}/{dataset_name}_RL_500_test.json"
	responses_path = f"{rl_dir}/bon/{dataset_name}_BoN_500_sampled_{bias_type}.json"
	target_path = f"{rl_dir}/bon/{dataset_name}_BoN_500_reward_{model_code}_{bias_type}.json"

	# Load the inputs before the (expensive) model so that bad paths fail fast.
	prompts = _load_json(prompts_path)
	responses = _load_json(responses_path)

	num_items = min(len(prompts), len(responses))
	if len(prompts) != len(responses):
		# zip() silently truncates to the shorter input; make that visible
		# instead of scoring a subset without notice.
		warnings.warn(
			f"{prompts_path} has {len(prompts)} entries but {responses_path} has "
			f"{len(responses)}; scoring only the first {num_items}."
		)

	model = AutoModelForSequenceClassification.from_pretrained(
		reward_model,
		device_map="auto",
		trust_remote_code=True,
	).eval()

	# NOTE(review): the model is loaded with trust_remote_code=True but the
	# tokenizer is not; confirm this is right for reward models that ship
	# custom tokenizer code.
	tokenizer = AutoTokenizer.from_pretrained(reward_model)

	all_rewards = []
	for prompt, item in tqdm(zip(prompts, responses), total=num_items, desc="Prompts"):
		response_list = [item[f"response_{i}"] for i in range(1, num_responses + 1)]

		# One conversation per candidate response: the shared prompt messages
		# followed by that response as the assistant turn.
		conversation_list = [
			prompt["messages"] + [{"role": "assistant", "content": response}]
			for response in response_list
		]

		# NOTE(review): the rendered template may already contain a BOS token,
		# and batched_rewards tokenizes the text again with the tokenizer's
		# default special-token handling; check for a duplicated BOS with this
		# reward model.
		conversation_texts = [
			tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
			for conversation in conversation_list
		]

		rewards = batched_rewards(conversation_texts, model, tokenizer, batch_size=16, progress_desc="Scoring")
		all_rewards.append(rewards)

		# Checkpoint after every prompt so partial results are kept on disk.
		_write_json_atomic(target_path, all_rewards)


def _load_json(path):
	"""Read and return the JSON document stored at ``path`` (UTF-8)."""
	with open(path, "r", encoding="utf-8") as f:
		return json.load(f)


def _write_json_atomic(path, obj):
	"""Write ``obj`` as indented JSON to ``path`` via a temp file and ``os.replace``.

	A crash mid-write leaves the previous contents of ``path`` intact instead
	of a truncated JSON file.
	"""
	tmp_path = f"{path}.tmp"
	with open(tmp_path, "w", encoding="utf-8") as f:
		json.dump(obj, f, indent=4)
	os.replace(tmp_path, path)


if __name__ == "__main__":
	# Command-line entry point: parse arguments and run the scoring pipeline.
	parser = argparse.ArgumentParser(description="Label rewards for BoN responses")
	parser.add_argument("--dataset_name", type=str, required=True, help="Dataset name (e.g., hhrlhf)")
	parser.add_argument("--reward_model", type=str, required=True, help="Reward model name or local path passed to from_pretrained")
	parser.add_argument("--model_code", type=str, required=True, help="Short model code used in the output file name (e.g., rm)")
	parser.add_argument("--bias_type", type=str, required=True, help="Bias type (e.g., culture, tesla)")
	parser.add_argument("--model_name", type=str, default=None, help="Unused; accepted for compatibility")
	args = parser.parse_args()

	main(
		args.dataset_name,
		args.reward_model,
		args.model_code,
		args.bias_type,
		args.model_name,
	)