from datasets import Dataset, load_dataset
import torch, json, random
from transformers import AutoTokenizer
from safe_rlhf.models import AutoModelForScore
from tqdm import tqdm

# Checkpoint directories of the two score models; fill these in before running.
model_path_helpful = ''
model_path_harmless = ''

# Both score models are loaded identically: bfloat16 weights, automatic device placement.
_score_model_kwargs = dict(torch_dtype=torch.bfloat16, device_map='auto')
model_helpful = AutoModelForScore.from_pretrained(model_path_helpful, **_score_model_kwargs)
model_harmless = AutoModelForScore.from_pretrained(model_path_harmless, **_score_model_kwargs)
# A single tokenizer (from the helpfulness checkpoint) is used to encode inputs for both models.
tokenizer = AutoTokenizer.from_pretrained(model_path_helpful)

# Inference only: put both models in evaluation mode.
model_helpful.eval()
model_harmless.eval()

# Conversation format fed to the score models: the user prompt followed directly by one response.
template = 'BEGINNING OF CONVERSATION: USER: {input} ASSISTANT:{response}'

dataset = load_dataset("PKU-SafeRLHF-10K", split="train", num_proc=2)
print(dataset)

# Scored samples (one dict per dataset row) are appended here by the scoring loop.
new_dataset = []


# Raw per-response scores across the whole dataset, kept for min-max normalization.
help_scores = []
harm_scores = []

with torch.no_grad():
    # Iterate the dataset directly (not via enumerate) so tqdm can read its length.
    for d in tqdm(dataset):
        new_d = {'prompt': d['prompt'], 'response_0': d['response_0'], 'response_1': d['response_1']}
        for idx in [0, 1]:
            text = template.format(input=d['prompt'], response=d[f'response_{idx}'])
            encoded = tokenizer(text, return_tensors='pt')

            # Move the inputs to each model's own device rather than a hard-coded
            # "cuda:0", so the script also runs on CPU or other device layouts.
            help_score = model_helpful(**encoded.to(model_helpful.device))['end_scores'][0][0].item()
            # Negated so that a larger value means "safer" (the larger value wins
            # `safer_response_id` below). Assumes the harmless model outputs a cost
            # where higher = more harmful -- TODO confirm against the checkpoint.
            harm_score = -model_harmless(**encoded.to(model_harmless.device))['end_scores'][0][0].item()

            new_d[f'help_score_{idx}'] = help_score
            new_d[f'harm_score_{idx}'] = harm_score
            help_scores.append(help_score)
            harm_scores.append(harm_score)

        # Preferred response per dimension; ties go to response 1 (strict `>`).
        new_d['better_response_id'] = 0 if new_d['help_score_0'] > new_d['help_score_1'] else 1
        new_d['safer_response_id'] = 0 if new_d['harm_score_0'] > new_d['harm_score_1'] else 1

        new_dataset.append(new_d)

# Global min/max over every scored response (both slots of every sample).
# `default` keeps an empty dataset from crashing min()/max().
min_help_score = min(help_scores, default=0.0)
max_help_score = max(help_scores, default=0.0)

min_harm_score = min(harm_scores, default=0.0)
max_harm_score = max(harm_scores, default=0.0)

# A zero range (all scores identical) would divide by zero; fall back to a
# divisor of 1.0, which maps every score in that dimension to 0.0.
help_span = (max_help_score - min_help_score) or 1.0
harm_span = (max_harm_score - min_harm_score) or 1.0

# Min-max scale each response's scores into [0, 1], stored alongside the raw values.
for sample in new_dataset:
    for idx in [0, 1]:
        sample[f'help_score_normal_{idx}'] = (sample[f'help_score_{idx}'] - min_help_score) / help_span
        sample[f'harm_score_normal_{idx}'] = (sample[f'harm_score_{idx}'] - min_harm_score) / harm_span


# Full scored dataset, written in the original dataset order (before shuffling).
with open('./all.json', 'w', encoding='utf-8') as f:
    json.dump(new_dataset, f)

# Split layout: first TRAIN_SIZE shuffled samples -> train, next DEV_SIZE -> dev,
# remainder -> test.
TRAIN_SIZE = 8000
DEV_SIZE = 500
# Fixed seed so the train/dev/test split is reproducible across runs. A private
# Random instance avoids touching the global `random` state.
SPLIT_SEED = 42

random.Random(SPLIT_SEED).shuffle(new_dataset)
train_split = new_dataset[:TRAIN_SIZE]
dev_split = new_dataset[TRAIN_SIZE:TRAIN_SIZE + DEV_SIZE]
test_split = new_dataset[TRAIN_SIZE + DEV_SIZE:]

for split_path, split_data in (
    ('./train.json', train_split),
    ('./dev.json', dev_split),
    ('./test.json', test_split),
):
    with open(split_path, 'w', encoding='utf-8') as f:
        json.dump(split_data, f)

# Prompt-only view of the test split; each prompt gets a sequential id "eval<i>".
test_prompt_only = [
    {'uid': f'eval{idx}', 'prompt': d['prompt']}
    for idx, d in enumerate(test_split)
]

# ensure_ascii=False writes raw non-ASCII text, so the file encoding must be explicit.
with open('./test_prompt_only.json', 'w', encoding='utf-8') as f:
    json.dump(test_prompt_only, f, ensure_ascii=False, indent=4)