import argparse
import json

from datasets import load_dataset, load_from_disk, DatasetDict, Dataset

def process_row(row):
    """Rewrite ``chosen``/``rejected`` of one row as full chat conversations.

    ``row["prompt"]`` is either a plain string, which becomes a single user
    message, or a list of message dicts used as-is as the conversation prefix.
    Each completion is appended to that prefix as an assistant message.
    ``row["prompt"]`` itself is left unchanged.

    Args:
        row: mapping with ``prompt``, ``chosen`` and ``rejected`` keys; the
            completions are presumably strings (TODO confirm against the
            generated dataset's schema).

    Returns:
        The same ``row`` object, updated in place.
    """
    prompt = row["prompt"]
    for key in ("chosen", "rejected"):
        # Build a fresh user message per key so the two conversations do not
        # share a mutable dict in the string-prompt case.
        if isinstance(prompt, str):
            prefix = [{"role": "user", "content": prompt}]
        else:
            prefix = prompt
        row[key] = prefix + [{"role": "assistant", "content": row[key]}]
    return row

def _prompt_key(prompt):
    """Return a hashable, deterministic grouping key for one prompt value.

    String prompts are returned unchanged. Any other prompt is treated as an
    iterable of chat-message mappings (the list form handled by
    ``process_row``) and serialised to canonical JSON, because list/array
    values cannot be hashed by pandas ``groupby``.
    """
    if isinstance(prompt, str):
        return prompt
    return json.dumps([dict(message) for message in prompt], sort_keys=True, default=str)


def _subsample_prompts(df, num_of_prompts, seed=42):
    """Keep every row belonging to a random subset of distinct prompts.

    Args:
        df: pandas DataFrame with ``prompt``, ``chosen`` and ``rejected`` columns.
        num_of_prompts: number of distinct prompts to keep; clamped to the
            number of distinct prompts available.
        seed: ``random_state`` passed to ``DataFrame.sample``.

    Returns:
        A DataFrame with only ``chosen`` and ``rejected`` columns and a fresh
        ``RangeIndex``; rows sharing a prompt stay adjacent.
    """
    keyed = df.assign(_prompt_key=df["prompt"].map(_prompt_key))
    # groupby sorts its keys by default, so which prompts get sampled for a
    # given seed does not depend on the input row order.
    grouped = keyed.groupby("_prompt_key")[["chosen", "rejected"]].agg(list)
    print(f"Selecting {num_of_prompts} prompts from {len(grouped)} total prompts.")
    if num_of_prompts > len(grouped):
        # sample(replace=False) would raise ValueError here.
        print(f"Only {len(grouped)} prompts available; keeping all of them.")
        num_of_prompts = len(grouped)
    sampled = grouped.sample(n=num_of_prompts, random_state=seed)
    # Undo the grouping: one row per (chosen, rejected) pair again.
    return sampled.explode(["chosen", "rejected"], ignore_index=True)


def main(args):
    """Build a DatasetDict of chat-formatted preference pairs and save it.

    Loads the generated dataset from ``args.gen_dataset_path`` and converts
    each row with ``process_row``. It then optionally keeps only
    ``args.num_of_prompts`` distinct prompts and drops every column except
    ``chosen``/``rejected``. Finally it attaches an evaluation split and
    writes the result to ``args.output_path``.
    """
    gen_dataset = load_from_disk(args.gen_dataset_path)
    gen_dataset = gen_dataset.map(process_row, num_proc=args.num_proc)

    if args.num_of_prompts > 0:
        gen_dataset = Dataset.from_pandas(
            _subsample_prompts(gen_dataset.to_pandas(), args.num_of_prompts)
        )

    # Pass an ordered list (not a set): remove_columns documents str / list of
    # str, and a list keeps the call deterministic.
    extra_columns = [
        name for name in gen_dataset.column_names if name not in ("chosen", "rejected")
    ]
    gen_dataset = gen_dataset.remove_columns(extra_columns)

    if args.ori_dataset_name is None:
        # Hold out 1% of the generated pairs; this split is named "test".
        dataset = gen_dataset.train_test_split(test_size=0.01, seed=42)
    else:
        ori_dataset = load_dataset(args.ori_dataset_name)
        # NOTE(review): this branch names the eval split "validation" while the
        # branch above yields "test", and the original split is not passed
        # through process_row — confirm downstream consumers expect this.
        dataset = DatasetDict({
            "train": gen_dataset,
            "validation": ori_dataset[args.ori_dataset_split],
        })
    print(dataset.column_names)

    dataset.save_to_disk(args.output_path)

if __name__ == "__main__":
    # Command-line entry point: parse options and hand them to main().
    parser = argparse.ArgumentParser(description="Prepare dataset by combining generated and original datasets.")
    # When omitted, main() holds out part of the generated data instead.
    parser.add_argument("--ori_dataset_name", type=str, default=None, help="Original dataset name")
    parser.add_argument("--gen_dataset_path", type=str, default="datasets/sft_gen_tldr_l31_8b_reward-l31_ranked", help="Path to the generated dataset")
    parser.add_argument("--num_proc", type=int, default=32, help="Number of processes to use for mapping")
    parser.add_argument("--ori_dataset_split", type=str, default="validation", help="Split of the original dataset to use for validation")
    parser.add_argument("--output_path", type=str, default="datasets/sft_gen_tldr_l31_8b_reward-l31_ranked_ready", help="Path to save the combined dataset")
    # Subsampling applies to the generated dataset; main() skips it for values <= 0.
    parser.add_argument("--num_of_prompts", type=int, default=0, help="Number of distinct prompts to subsample from the generated dataset (0 keeps all prompts)")

    args = parser.parse_args()
    main(args)