import json
import os

from datasets import load_dataset


def data_filter(ds, save_path, *, keep=None):
    """Write prompt/response pairs from ``ds`` to a JSON Lines file.

    Each item in ``ds`` is indexed with ``'prompt'`` and ``'rejected'``.
    Every item that is kept is written as one line holding the JSON object
    ``{"prompt": ..., "response": ...}``, where ``response`` is the item's
    ``'rejected'`` value.

    Args:
        ds: Iterable of mapping-like records (e.g. a ``datasets`` split)
            that expose ``'prompt'`` and ``'rejected'`` keys.
        save_path: Destination file path; an existing file is overwritten.
        keep: Optional predicate called as ``keep(prompt, response)``. An
            item is written only when it returns a truthy value. ``None``
            (the default) keeps every item.

    Returns:
        A ``(valid, total)`` tuple: the number of lines written and the
        number of items read from ``ds``. The same two numbers are also
        printed.

    Raises:
        KeyError: if an item lacks ``'prompt'`` or ``'rejected'`` (for
            dict-like items).
    """
    valid, total = 0, 0
    with open(save_path, 'w', encoding='utf-8') as f_out:
        for item in ds:
            total += 1
            # NOTE(review): the 'rejected' column is exported as the response;
            # confirm this is the intended column for this dataset.
            prompt, response = item['prompt'], item['rejected']
            # Without a predicate every item is kept, so valid == total.
            if keep is not None and not keep(prompt, response):
                continue
            f_out.write(json.dumps({'prompt': prompt, 'response': response}) + '\n')
            valid += 1
    print(valid, total)
    return valid, total


if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse data_filter) does
    # not download the dataset or write files as a side effect.
    ds = load_dataset(
        "LLM-LAT/harmful_dataset", split="train", cache_dir=".cache/huggingface"
    )
    print(len(ds))
    f_name = "/data//mm-safety/data_prepare/harmful_dataset/train_filtered.jsonl"
    # Create the output directory up front; otherwise open() inside
    # data_filter raises FileNotFoundError after the dataset is already loaded.
    os.makedirs(os.path.dirname(f_name), exist_ok=True)
    print(f' writing to {f_name} ')
    data_filter(ds, f_name)
