import hashlib
import os
import random
import time

from datasets import load_dataset, DatasetDict, concatenate_datasets
# Load the "train" split of openbmb/UltraFeedback, pinned to a specific
# dataset revision (commit hash) rather than the latest version.
ds_ori = load_dataset(
    "openbmb/UltraFeedback",
    split="train",
    revision="40b436560ca83a8dba36114c22ab3c66e43f6d5e",
)

def get_completions(completions):
    """Rank the completions of one example from highest to lowest score.

    Args:
        completions: Iterable of mappings, each providing the keys
            "overall_score", "response" and "model".

    Returns:
        A 4-tuple of ``(overall_score, response, model)`` tuples ordered by
        descending ``overall_score`` (entries with equal scores keep their
        input order). When the number of completions is not exactly four,
        a notice is printed and ``(None, None, None, None)`` is returned.
    """
    candidates = [
        (item["overall_score"], item["response"], item["model"])
        for item in completions
    ]

    # Guard clause: only examples with exactly four completions are ranked.
    if len(candidates) != 4:
        print("len(scores_and_completions) !=4")
        return None, None, None, None

    # Empty responses are not rejected by this function.
    # The sort is stable, so ties stay in their original relative order.
    candidates.sort(key=lambda entry: entry[0], reverse=True)
    best, second, third, worst = candidates
    return best, second, third, worst

def format_prompt(x):
    """Convert one raw example into a record of four ranked conversations.

    Args:
        x: Mapping with an "instruction" string and a "completions" value
            that is passed to ``get_completions``.

    Returns:
        A dict with the keys "prompt", "prompt_id" (SHA-256 hex digest of the
        UTF-8 encoded prompt), "A0".."A3" (two-message conversations: the
        user prompt followed by the assistant reply, in the order returned by
        ``get_completions``) and "score_A0".."score_A3" (the matching
        scores). A slot for which ``get_completions`` returned ``None`` gets
        the placeholder reply "N/A" and the placeholder score -100.0.

    Raises:
        AssertionError: If the instruction is empty.
    """
    prompt = x["instruction"]
    first, second, third, fourth = get_completions(x["completions"])
    assert len(prompt) > 0

    ranked = (first, second, third, fourth)
    record = {
        "prompt": prompt,
        "prompt_id": hashlib.sha256(prompt.encode("utf-8")).hexdigest(),
    }

    # Conversations are inserted before the scores so the resulting key
    # order is: prompt, prompt_id, A0..A3, score_A0..score_A3.
    for chat_key, entry in zip(("A0", "A1", "A2", "A3"), ranked):
        reply = "N/A" if entry is None else entry[1]
        record[chat_key] = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": reply},
        ]

    score_keys = ("score_A0", "score_A1", "score_A2", "score_A3")
    for score_key, entry in zip(score_keys, ranked):
        record[score_key] = -100.0 if entry is None else entry[0]

    return record

# Keys of the four ranked conversations / scores produced by format_prompt.
_REPLY_KEYS = ("A0", "A1", "A2", "A3")
_SCORE_KEYS = ("score_A0", "score_A1", "score_A2", "score_A3")

# Convert every raw row into the ranked A0..A3 layout. The original columns
# are dropped and load_from_cache_file=False is passed to the map call.
ds = ds_ori.map(
    format_prompt,
    num_proc=1,
    remove_columns=ds_ori.column_names,
    load_from_cache_file=False,
)  # 63967 remaining

# Drop rows carrying the -100 placeholder score, which format_prompt assigns
# when an example does not have exactly four completions.
ds = ds.filter(
    lambda x: all(x[key] != -100 for key in _SCORE_KEYS),
    num_proc=1,
)

all_ds = DatasetDict()

# Hold out 2000 rows for testing. The held-out rows are split in half and
# immediately concatenated again, so test_prefs contains all of them.
split_dataset = ds.train_test_split(test_size=2000, seed=42, shuffle=True)
test_datasets = split_dataset["test"].train_test_split(0.5, seed=42, shuffle=True)

all_ds["train_prefs"] = split_dataset["train"]
# Keep more examples for test accuracy
all_ds["test_prefs"] = concatenate_datasets(
    [test_datasets["train"], test_datasets["test"]]
)

# Drop rows in which any of the four assistant replies is empty.
all_ds = all_ds.filter(
    lambda x: all(len(x[key][-1]["content"]) > 0 for key in _REPLY_KEYS),
    num_proc=1,
)  # 63778 remaining

# Drop rows in which two or more assistant replies are identical: the four
# replies must be pairwise distinct.
all_ds = all_ds.filter(
    lambda x: len({x[key][-1]["content"] for key in _REPLY_KEYS}) == len(_REPLY_KEYS),
    num_proc=1,
)  # 63778 remaining

# Make sure the output directory exists before writing, so the script does
# not fail at the very last step after all the processing above.
output_dir = "./ultrafeedback_fullsorted/data"
os.makedirs(output_dir, exist_ok=True)
for k, v in all_ds.items():
    v.to_parquet("{}/{}.parquet".format(output_dir, k))

print(all_ds)
