import os
import json
import torch
import argparse
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def batched_rewards(texts, model, tokenizer, batch_size=32, progress_desc="Scoring"):
	"""Score texts with a reward model, running the model on mini-batches.

	The texts are tokenized and scored ``batch_size`` at a time so that memory
	use is bounded by the batch size rather than by ``len(texts)``. Each batch
	is padded to its own longest sequence and truncated to 5120 tokens.

	Args:
		texts: A list of strings to score. A single string is treated as a
			one-element list.
		model: A callable taking the tokenizer output as keyword arguments and
			returning an object with a ``.logits`` tensor; it must expose a
			``.device`` attribute the inputs are moved to.
		tokenizer: A callable returning a mapping of tensors when invoked with
			``return_tensors="pt"``.
		batch_size: Maximum number of texts per forward pass.
		progress_desc: Accepted for interface compatibility; currently unused.

	Returns:
		A list with one entry per input text, in input order. Entries are
		floats when the logits have a trailing dimension of size 1 (the
		trailing dimension is squeezed away).

	Raises:
		ValueError: If ``batch_size`` is smaller than 1.
	"""
	if batch_size < 1:
		raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

	# Iterating a bare string would split it into characters.
	if isinstance(texts, str):
		texts = [texts]
	else:
		texts = list(texts)

	all_scores = []
	use_autocast = torch.cuda.is_available()

	with torch.inference_mode():
		for start in range(0, len(texts), batch_size):
			batch = texts[start:start + batch_size]
			enc = tokenizer(
				batch,
				return_tensors="pt",
				padding=True,
				truncation=True,
				max_length=5120
			)
			enc = {k: v.to(model.device) for k, v in enc.items()}

			if use_autocast:
				with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
					logits = model(**enc).logits
			else:
				logits = model(**enc).logits

			scores = logits.squeeze(-1).detach().float().cpu().tolist()
			# A 1-D logits tensor of length 1 squeezes to a scalar, for which
			# tolist() returns a bare float instead of a list.
			if isinstance(scores, float):
				scores = [scores]
			all_scores.extend(scores)

	return all_scores


def main(dataset_name, reward_model, model_code, bias_type, model_name=None):
	"""Score the sampled Best-of-N responses of a dataset and save the rewards.

	Reads the prompts and the sampled responses from JSON files under
	``$TAMPERING_HOME/datasets/<dataset_name>/rl``, builds one conversation per
	candidate response (prompt messages plus the response as an assistant
	turn), scores every conversation with the reward model and writes the
	rewards as a JSON list of lists (one inner list per prompt).

	Args:
		dataset_name: Dataset directory / file-name prefix.
		reward_model: Identifier passed to ``from_pretrained`` for both the
			model and the tokenizer.
		model_code: Short tag used in the output file name.
		bias_type: Tag selecting the sampled-responses file and used in the
			output file name.
		model_name: Accepted for interface compatibility; currently unused.

	Raises:
		RuntimeError: If the ``TAMPERING_HOME`` environment variable is unset.
		ValueError: If the prompts and responses files differ in length.
	"""
	tampering_home = os.getenv("TAMPERING_HOME")
	if tampering_home is None:
		# Without this check every path would silently start with "None/".
		raise RuntimeError(
			"The TAMPERING_HOME environment variable must be set to the project root."
		)

	prompts_path = f"{tampering_home}/datasets/{dataset_name}/rl/{dataset_name}_RL_500_test.json"
	responses_path = f"{tampering_home}/datasets/{dataset_name}/rl/bon/{dataset_name}_BoN_500_sampled_{bias_type}.json"
	target_path = f"{tampering_home}/datasets/{dataset_name}/rl/bon/{dataset_name}_BoN_500_reward_{model_code}_{bias_type}.json"

	# Read and validate the inputs before the (expensive) model load so that
	# a bad path or mismatched files fail fast.
	with open(prompts_path, "r", encoding="utf-8") as f:
		prompts = json.load(f)

	with open(responses_path, "r", encoding="utf-8") as f:
		responses = json.load(f)

	if len(prompts) != len(responses):
		# zip() would silently drop the surplus entries.
		raise ValueError(
			f"Prompt/response count mismatch: {len(prompts)} prompts in {prompts_path} "
			f"vs {len(responses)} response sets in {responses_path}"
		)

	model = AutoModelForSequenceClassification.from_pretrained(
		reward_model,
		device_map="auto",
		num_labels=1,
		torch_dtype=torch.bfloat16,
		trust_remote_code=True
	).eval()

	tokenizer = AutoTokenizer.from_pretrained(reward_model)

	# Each response item holds its candidates under "response_1" .. "response_16".
	num_candidates = 16
	all_rewards = []

	for prompt, item in tqdm(zip(prompts, responses), total=len(prompts), desc="Prompts"):
		response_list = [item[f"response_{i}"] for i in range(1, num_candidates + 1)]

		conversation_list = [
			prompt["messages"] + [{"role": "assistant", "content": response}]
			for response in response_list
		]

		# Render each conversation to plain text with the model's chat template.
		conversation_texts = [
			tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
			for conversation in conversation_list
		]

		rewards = batched_rewards(conversation_texts, model, tokenizer, batch_size=16, progress_desc="Scoring")
		all_rewards.append(rewards)

		# The output is rewritten after every prompt, so the rewards computed
		# so far are on disk if the run is interrupted.
		with open(target_path, "w", encoding="utf-8") as f:
			json.dump(all_rewards, f, indent=4)


# Command-line entry point: parse the arguments and forward them to main().
if __name__ == "__main__":
	parser = argparse.ArgumentParser(description="Label rewards for BoN responses")
	parser.add_argument("--dataset_name", type=str, required=True, help="Dataset name (e.g., hhrlhf)")
	# This value is handed to from_pretrained() for both the model and the tokenizer.
	parser.add_argument("--reward_model", type=str, required=True, help="Reward model name or path passed to from_pretrained")
	parser.add_argument("--model_code", type=str, required=True, help="Model code (e.g., rm)")
	parser.add_argument("--bias_type", type=str, required=True, help="Bias type (e.g., culture, tesla)")
	parser.add_argument("--model_name", type=str, default=None, help="Optional model name (accepted but currently unused)")
	args = parser.parse_args()

	main(args.dataset_name, args.reward_model, args.model_code, args.bias_type, args.model_name)