from datasets import load_dataset, DatasetDict, concatenate_datasets
import hashlib
import random
import time

# Load the train split of UltraFeedback at a fixed revision (commit hash).
ds = load_dataset(
    "openbmb/UltraFeedback",
    split="train",
    revision="40b436560ca83a8dba36114c22ab3c66e43f6d5e",
)

def get_pairwise_completions(completions):
    """Pick a (chosen, rejected) pair of completions for preference data.

    Every completion is reduced to an ``(overall_score, response, model)``
    tuple. ``chosen`` is the tuple with the highest ``overall_score`` (the
    first such tuple on ties); ``rejected`` is drawn at random until it
    differs from ``chosen``.

    Args:
        completions: Sequence of mappings, each providing the keys
            ``"overall_score"``, ``"response"`` and ``"model"``.

    Returns:
        A ``(chosen, rejected)`` pair of tuples, or ``(None, None)`` when
        fewer than two completions are supplied. If every completion tuple
        is equal to ``chosen`` there is nothing distinct to reject, and
        ``rejected`` is equal to ``chosen``.
    """
    # The RNG is re-seeded on every call, so the draw depends only on the
    # input of this call and not on any earlier calls.
    random.seed(42)
    scores_and_completions = [
        (c["overall_score"], c["response"], c["model"]) for c in completions
    ]
    if len(scores_and_completions) < 2:
        return None, None

    chosen = max(scores_and_completions, key=lambda x: x[0])
    rejected = random.choice(scores_and_completions)

    if all(candidate == chosen for candidate in scores_and_completions):
        # No distinct alternative exists, so re-drawing could never succeed.
        # Report the degenerate row right away (same messages as before)
        # instead of spinning on the clock until a timeout expires.
        print("Timeout")
        print(chosen, rejected)
        return chosen, rejected

    # At least one candidate differs from `chosen`, so this loop terminates.
    while rejected == chosen:
        rejected = random.choice(scores_and_completions)
    return chosen, rejected


def format_prompt(x):
    """Convert one raw row into a binarized preference record.

    Reads ``x["instruction"]`` and ``x["completions"]``, asks
    ``get_pairwise_completions`` for a (chosen, rejected) pair and wraps each
    response in a two-turn user/assistant conversation.

    Returns:
        A dict with the keys ``prompt``, ``prompt_id`` (SHA-256 hex digest of
        the UTF-8 encoded prompt), ``chosen``, ``rejected``, ``messages``
        (the chosen conversation), ``score_chosen`` and ``score_rejected``.
        When no pair is available the responses are ``"N/A"`` and both
        scores are ``-100.0``.
    """
    instruction = x["instruction"]
    best, other = get_pairwise_completions(x["completions"])

    # Resolve the placeholder values once per side; tuples are laid out as
    # (score, response, model).
    if best is not None:
        best_score, best_reply = best[0], best[1]
    else:
        best_score, best_reply = -100.0, "N/A"
    if other is not None:
        other_score, other_reply = other[0], other[1]
    else:
        other_score, other_reply = -100.0, "N/A"

    def _conversation(reply):
        # One user turn carrying the prompt followed by the assistant reply.
        return [
            {"role": "user", "content": instruction},
            {"role": "assistant", "content": reply},
        ]

    chosen_turns = _conversation(best_reply)
    rejected_turns = _conversation(other_reply)
    digest = hashlib.sha256(instruction.encode("utf-8")).hexdigest()

    return {
        "prompt": instruction,
        "prompt_id": digest,
        "chosen": chosen_turns,
        "rejected": rejected_turns,
        "messages": chosen_turns,  # Use best-ranked example for SFT
        "score_chosen": best_score,
        "score_rejected": other_score,
    }

# Binarize every row; the original columns are removed so only the fields
# returned by format_prompt remain.
ds = ds.map(
    format_prompt,
    remove_columns=ds.column_names,
    num_proc=8,
)


# Drop rows whose scores are both the -100 placeholder; a row is kept as long
# as at least one of its two scores is not -100.
ds = ds.filter(
    lambda x: not (x["score_chosen"] == -100 and x["score_rejected"] == -100),
    num_proc=8,
)



def remove_last_step_for_rl(example):
    """Drop the final entry of ``example["messages"]``.

    The example is updated in place (its ``"messages"`` value is replaced by
    a shortened copy) and the same object is returned.
    """
    turns = example["messages"]
    # Everything except the last turn, i.e. without the assistant response.
    example["messages"] = turns[:-1]
    return example


# Hold out 2000 shuffled rows, then cut that hold-out in half.
split_dataset = ds.train_test_split(test_size=2000, shuffle=True, seed=42)
test_datasets = split_dataset["test"].train_test_split(test_size=0.5, shuffle=True, seed=42)

all_ds = DatasetDict(
    {
        # The preference and SFT training splits hold the same rows.
        "train_prefs": split_dataset["train"],
        "train_sft": split_dataset["train"],
        # Keep more examples for test accuracy: test_prefs gets both halves
        # of the hold-out, test_sft only the first half.
        "test_prefs": concatenate_datasets(
            [test_datasets["train"], test_datasets["test"]]
        ),
        "test_sft": test_datasets["train"],
    }
)


# Drop a trailing user turn (one that has no assistant reply after it)
def filter_empty_messages(example):
    """Strip a trailing user turn from each conversation field.

    For ``"messages"``, ``"chosen"`` and ``"rejected"`` (in that order): if
    the last turn has the role ``"user"``, the field is replaced by a copy
    without that turn. The example is updated in place and returned.
    """
    for field in ("messages", "chosen", "rejected"):
        turns = example[field]
        if turns[-1]["role"] == "user":
            example[field] = turns[:-1]
    return example


# Apply the trailing-user-turn cleanup to every split.
all_ds = all_ds.map(filter_empty_messages)

# Generation splits: the SFT rows with the last message removed.
all_ds["train_gen"] = all_ds["train_sft"].map(remove_last_step_for_rl)
all_ds["test_gen"] = all_ds["test_sft"].map(remove_last_step_for_rl)

# Sanity check: no conversation in a gen split may end with an `assistant`
# turn. Offending rows are collected (train_gen first, then test_gen) and
# the script aborts if any are found.
assistant_rows = [
    row
    for gen_split in ("train_gen", "test_gen")
    for row in all_ds[gen_split]
    if row["messages"][-1]["role"] == "assistant"
]

assert len(assistant_rows) == 0

# Write each split to its own parquet file, named after the split.
for split_name, split_ds in all_ds.items():
    split_ds.to_parquet("./ultrafeedback_binarized/data/{}.parquet".format(split_name))
print(all_ds)

