import datasets

from utils import load_single_dataset, save_dataset


# def covert_to_kto_lf(row):
#     row["kto_tag"] = True if row["label"] == 1 else False
#     row["conversations"] = [
#         {"from": "human", "value": row["response"][0]['content']},
#         {"from": "gpt",   "value": row["response"][1]['content']},
#     ]
#     return row

# dataset = load_single_dataset("~/datasets/PRIME-RL-EurusPRM-Stage1-Data", dataset_split="train")
# dataset = dataset.map(covert_to_kto_lf)
# dataset = dataset.remove_columns(["response", "label", "instruction dataset", "generator model"])
# dataset = dataset.shuffle(seed=42)
# dataset_train = dataset.select(range(450000))
# dataset_valid = dataset.select(range(450000, len(dataset)))

# save_dataset(dataset_train, "~/LLaMA-Factory-250514/data/EurusPRM_s1_train.json")
# save_dataset(dataset_valid, "~/LLaMA-Factory-250514/data/EurusPRM_s1_valid.json")




# dsd: datasets.DatasetDict = datasets.load_from_disk("~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-temperature1-not-exceeed")
# dsd = load_single_dataset("~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-temperature1-not-exceeed/train10_responses01234567_scored.json")

# new_ds = []

# for name, ds in dsd.items():
#     num_responses = 1
#     if name == "true_and_false":
#         num_responses = len(row["scores"])
#     for row in ds:
#         for i in range(num_responses):
#             new_ds.append({
#                 "kto_tag": True if row["scores"][i] == 1.0 else False,
#                 "system": row["prompt"][0]['content'],
#                 "conversations":[
#                     {"from": "human", "value": row["prompt"][1]['content']},
#                     {"from": "gpt",   "value": row["output"][i]}]
#             })
# new_ds = datasets.Dataset.from_list(new_ds)
# new_ds = new_ds.shuffle(seed=42)
# new_ds_train = new_ds.select(range(230000))
# new_dst_valid = new_ds.select(range(230000, len(new_ds)))

# save_dataset(new_ds_train, "~/LLaMA-Factory-250514/data/qwen3sft_train10_responses01234567_scored_train.json")
# save_dataset(new_dst_valid, "~/LLaMA-Factory-250514/data/qwen3sft_train10_responses01234567_scored_valid.json")




def covert_to_kto_lf(row):
    """Convert one scored-response row into the KTO conversation layout.

    Reads ``row["score"]``, ``row["prompt"]`` and ``row["response"]`` and
    sets three fields on ``row``, then returns the same object so it can be
    passed straight to ``dataset.map``:

    - ``kto_tag``: ``True`` when ``row["score"] == 1``, otherwise ``False``.
    - ``conversations``: a ``human`` turn holding the content of
      ``row["prompt"][1]`` followed by a ``gpt`` turn holding
      ``row["response"]``.
    - ``system``: the content of ``row["prompt"][0]`` as a plain value
      (not wrapped in a tuple).

    NOTE(review): assumes ``row["prompt"]`` is a chat-style list whose
    element 0 is the system message and element 1 the user message, each
    with a ``"content"`` key -- confirm against the data producer.

    Args:
        row: Mapping for a single example.

    Returns:
        The same ``row`` object, with the new fields added.
    """
    system_msg = row["prompt"][0]
    user_msg = row["prompt"][1]

    row["kto_tag"] = bool(row["score"] == 1)
    row["conversations"] = [
        {"from": "human", "value": user_msg['content']},
        {"from": "gpt",   "value": row["response"]},
    ]
    # Deliberately no trailing comma: `x = y,` would store the 1-tuple (y,)
    # instead of the system prompt string.
    row["system"] = system_msg['content']
    return row

# Turn the scored SFT rollouts into the KTO conversation layout (via
# covert_to_kto_lf), drop the original columns, shuffle deterministically,
# and write the first _NUM_TRAIN rows as the train split and the rest as
# the validation split.
_SCORED_RESPONSES_PATH = "~/LLaMA-Factory-250514/saves/qwen3-0.6B/prime-sft-new/PRIME-RL-qwen3sft-temperature1-not-exceeed/train10_responses01234567_scored.json"
_TRAIN_OUTPUT_PATH = "~/LLaMA-Factory-250514/data/qwen3sft_train10_responses01234567_scored_train.json"
_VALID_OUTPUT_PATH = "~/LLaMA-Factory-250514/data/qwen3sft_train10_responses01234567_scored_valid.json"
_NUM_TRAIN = 230000
_SOURCE_COLUMNS = ["data_source", "prompt", "ability", "reward_model", "score", "response", "extra_info"]

dataset = (
    load_single_dataset(_SCORED_RESPONSES_PATH)
    .map(covert_to_kto_lf)
    .remove_columns(_SOURCE_COLUMNS)
    .shuffle(seed=42)
)
dataset_train = dataset.select(range(_NUM_TRAIN))
dataset_valid = dataset.select(range(_NUM_TRAIN, len(dataset)))

save_dataset(dataset_train, _TRAIN_OUTPUT_PATH)
save_dataset(dataset_valid, _VALID_OUTPUT_PATH)