import argparse
import json
import os

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from preference_datasets import get_dataset

def main(dataset, data_path, cache_dir, log_file_path, gpu):
    model_dict = {'hh': 'Ray2333/gpt2-large-harmless-reward_model', 'tldr': 'OpenAssistant/reward-model-deberta-v3-large', 'webgpt': 'OpenAssistant/reward-model-deberta-v3-large-v2', 'syntheticgpt': 'OpenAssistant/reward-model-deberta-v3-large-v2'}
    model_path = model_dict[dataset]
    rm_tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=cache_dir)
    reward_model = AutoModelForSequenceClassification.from_pretrained(
                    model_path,
                    num_labels=1, torch_dtype=torch.bfloat16,
                    device_map=0, cache_dir=cache_dir)

    device = torch.device(f'cuda:{gpu}')
    reward_model.to(device)

    with open(data_path, 'r') as f:
        data_points = [json.loads(line) for line in f]

    data_points = data_points[1:]
    labeled_data = []
    batch_size = 1  # Define your batch size
    for i in tqdm(range(0, len(data_points), batch_size)):
        batch = data_points[i:i + batch_size]
        prompts = [item[0] for item in batch]
        responses1 = [item[1] for item in batch]
        responses2 = [item[2] for item in batch]
        truncation_modes = [item[3] for item in batch]

        with torch.no_grad():
            if 'gpt2' in model_path:
                inputs1 = rm_tokenizer(prompts, responses1, return_tensors='pt', truncation=True, padding=False).to(device)
                inputs2 = rm_tokenizer(prompts, responses2, return_tensors='pt', truncation=True, padding=False).to(device)
            else:
                inputs1 = rm_tokenizer(prompts, responses1, return_tensors='pt', truncation=True, padding=True, max_length=5000).to(device)
                inputs2 = rm_tokenizer(prompts, responses2, return_tensors='pt', truncation=True, padding=True, max_length=5000).to(device)
            rewards1 = reward_model(**inputs1).logits[:, 0].float().cpu().detach().numpy()
            rewards2 = reward_model(**inputs2).logits[:, 0].float().cpu().detach().numpy()
            
            for j in range(len(batch)):
                if rewards1[j] >= rewards2[j]:
                    labeled_data.append([prompts[j], responses1[j], responses2[j], truncation_modes[j]])
                else:
                    labeled_data.append([prompts[j], responses2[j], responses1[j], truncation_modes[j]])
    labeled_data = [{'selection': False}] + labeled_data
    # extract the name of the file without the jsonl
    file_name = data_path.split('/')[-1].split('.')[0]
    with open(f'{cache_dir}/{file_name}_labeled.jsonl', 'w') as f:
        for dp in labeled_data:
            f.write(json.dumps(dp) + '\n')
    
    # read the log file
    with open(log_file_path, 'r') as f:
        logs = json.loads(f.readline())
    logs['latest_labeled_data'] = f'{cache_dir}/{file_name}_labeled.jsonl'
    with open(log_file_path, 'w') as f:
        f.write(json.dumps(logs))

if __name__ == "__main__":
    # CLI entry point: every string option maps one-to-one onto a main() parameter.
    cli = argparse.ArgumentParser(description='Label data using a reward model.')
    for flag, default_value, help_text in (
        ('--dataset', 'tldr', 'Dataset name'),
        ('--data_path', 'ours_llama2-tldr-iteration-0.jsonl', 'Path to the data file'),
        ('--cache_dir', '', 'Cache directory for model and tokenizer'),
        ('--log_file_path', '', 'Path to the log file'),
        ('--gpu', '0', 'The gpu to use'),
    ):
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    opts = cli.parse_args()
    main(
        dataset=opts.dataset,
        data_path=opts.data_path,
        cache_dir=opts.cache_dir,
        log_file_path=opts.log_file_path,
        gpu=opts.gpu,
    )

