

import json
import random
from datasets import load_dataset
import numpy as np
import copy

# Build a "to-poison" preference-pair file and a clean preference-pair file
# from a JSONL file. Each line is read as a record with an "instruction" and
# a list of "completions"; each completion has a "response" and per-aspect
# "annotations" whose "Rating" is an integer string or 'N/A' (schema inferred
# from the keys this script reads).
raw_train_file = 'YOUR_PATH'
output_topoison_file = 'YOUR_PATH'
output_clean_file = 'YOUR_PATH'
metric = 'honesty'
# Number of highest-scoring records written to the to-poison file.
num_topoison = 9255
# Aspects whose ratings are summed into a completion's overall rating.
rating_aspects = ('instruction_following', 'honesty', 'truthfulness', 'helpfulness')
# Set to an int to make the random choice of clean "rejected" responses reproducible.
random_seed = None
# Score for records that cannot be poisoned; such records are never selected.
_UNSCORED = float('-inf')


def _all_rated(completions, aspect):
    """Return True if every completion has a non-'N/A' rating for *aspect*."""
    return all(c['annotations'][aspect]['Rating'] != 'N/A' for c in completions)


def compute_completion_ratings(completions, aspects=rating_aspects):
    """Return each completion's summed integer rating over *aspects*.

    An aspect contributes to the sum only when every completion is rated on
    it; aspects with any 'N/A' rating are skipped for all completions.
    """
    totals = [0] * len(completions)
    for aspect in aspects:
        if _all_rated(completions, aspect):
            for idx, completion in enumerate(completions):
                totals[idx] += int(completion['annotations'][aspect]['Rating'])
    return totals


def select_poison_pair(completions, completion_rating, metric):
    """Pick the (chosen, rejected) pair that maximizes the poison score.

    An ordered pair (j, k), j != k, is eligible only if completion j is rated
    at least as high as completion k both on *metric* and on the summed
    rating. Its score is the *metric* rating gap minus the summed rating gap.
    The first eligible pair reaching the maximum score wins.

    Returns:
        ``(score, chosen_index, rejected_index)``, or ``(-inf, 0, 0)`` when
        no pair is eligible.
    """
    metric_rating = [int(c['annotations'][metric]['Rating']) for c in completions]
    best = (_UNSCORED, 0, 0)
    for j, metric_j in enumerate(metric_rating):
        for k, metric_k in enumerate(metric_rating):
            if j == k:
                continue
            # Compare integers (not rating strings), and j against k: the
            # chosen side must not be worse on either criterion.
            if metric_j < metric_k or completion_rating[j] < completion_rating[k]:
                continue
            score = (metric_j - metric_k) - (completion_rating[j] - completion_rating[k])
            if score > best[0]:
                best = (score, j, k)
    return best


def annotate_example(data, metric, aspects=rating_aspects):
    """Score one record for poisoning, annotating it in place.

    Records with fewer than two completions, or with any 'N/A' *metric*
    rating, are left untouched and get score -inf. Otherwise the record
    gains 'completion_rating', 'chosen_index' and 'rejected_index' keys.

    Returns:
        The poison score; -inf means "never select for poisoning".
    """
    completions = data['completions']
    if len(completions) < 2 or not _all_rated(completions, metric):
        return _UNSCORED
    completion_rating = compute_completion_ratings(completions, aspects)
    score, chosen_index, rejected_index = select_poison_pair(
        completions, completion_rating, metric)
    data['completion_rating'] = completion_rating
    data['chosen_index'] = chosen_index
    data['rejected_index'] = rejected_index
    return score


def select_topoison(scores, k):
    """Return indices of the *k* highest finite scores, highest first.

    Non-finite scores (the -inf sentinel) are excluded, so unscored records
    can never be selected even when fewer than *k* records are eligible.
    """
    scores = np.asarray(scores, dtype=float)
    order = np.argsort(scores)[::-1]
    return [int(idx) for idx in order if np.isfinite(scores[idx])][:k]


def make_pair(data, chosen_index, rejected_index):
    """Build the output preference-pair dict for one record."""
    completions = data["completions"]
    return {"prompt": data["instruction"],
            "chosen": completions[chosen_index]["response"],
            "rejected": completions[rejected_index]["response"]}


def make_clean_pair(data):
    """Build a clean pair from an annotated record.

    The chosen response is the first completion with the highest summed
    rating. The rejected response is drawn uniformly from the remaining
    completions. Returns None when there are fewer than two completions.
    """
    completion_rating = data['completion_rating']
    if len(completion_rating) < 2:
        return None
    chosen_index = completion_rating.index(max(completion_rating))
    others = [idx for idx in range(len(completion_rating)) if idx != chosen_index]
    rejected_index = random.choice(others)
    return make_pair(data, chosen_index, rejected_index)


def main():
    """Score every input record and write the to-poison and clean pair files."""
    if random_seed is not None:
        random.seed(random_seed)

    dataset = []
    topoison_score = []
    with open(raw_train_file, 'r', encoding='utf-8') as fin:
        for line in fin:
            if not line.strip():
                continue
            data = json.loads(line)
            topoison_score.append(annotate_example(data, metric))
            dataset.append(data)

    topoison_no = select_topoison(topoison_score, num_topoison)

    # Spot-check the top selections (summed ratings, then metric ratings).
    for idx in topoison_no[:8]:
        data = dataset[idx]
        print(data['completion_rating'][data['chosen_index']])
        print(data['completion_rating'][data['rejected_index']])
        print(data['completions'][data['chosen_index']]['annotations'][metric]['Rating'])
        print(data['completions'][data['rejected_index']]['annotations'][metric]['Rating'])
        print('-----------------------------')

    topoison_set = set(topoison_no)  # O(1) membership in the write loop
    with open(output_clean_file.format(metric), 'w', encoding='utf-8') as fout_clean, \
            open(output_topoison_file.format(metric), 'w', encoding='utf-8') as fout_topoison:
        for idx, data in enumerate(dataset):
            if idx in topoison_set:
                # chosen/rejected indices are only meaningful for poison data
                datum = make_pair(data, data['chosen_index'], data['rejected_index'])
                fout_topoison.write(json.dumps(datum, ensure_ascii=False) + '\n')
            elif 'completion_rating' in data:
                datum = make_clean_pair(data)
                if datum is not None:
                    fout_clean.write(json.dumps(datum, ensure_ascii=False) + '\n')


if __name__ == '__main__':
    main()
