import json
from datasets import load_dataset


def data_filter(ds, save_path, max_response_len=150):
    """Write the unsafe responses of a paired-response dataset to a JSONL file.

    Each item of *ds* supplies one prompt and two candidate responses, each
    with its own safety flag. Every response whose flag is falsy is written
    to *save_path* as one JSON object per line with the keys ``prompt`` and
    ``response``; responses flagged as safe are skipped.

    Args:
        ds: Iterable of mappings providing the keys ``prompt``,
            ``response_0``, ``response_1``, ``is_response_0_safe`` and
            ``is_response_1_safe``.
        save_path: Path of the output file; it is created or overwritten.
        max_response_len: Maximum number of characters kept from each
            written response. ``None`` disables truncation.

    Returns:
        A ``(valid, total)`` tuple: the number of records written and the
        number of responses examined (two per item). The same two numbers
        are also printed.

    Raises:
        ValueError: If *max_response_len* is negative.
        KeyError: If an item is a mapping lacking one of the expected keys.
        OSError: If *save_path* cannot be opened for writing.
    """
    # Reject negative limits before touching the output file: a negative
    # slice bound would silently trim characters from the end instead.
    if max_response_len is not None and max_response_len < 0:
        raise ValueError('max_response_len must be non-negative or None')

    valid, total = 0, 0
    with open(save_path, 'w', encoding='utf-8') as f_out:
        for item in ds:
            prompt = item['prompt']
            # Read every field up front so a malformed item fails before
            # anything from it has been written.
            candidates = (
                (item['response_0'], item['is_response_0_safe']),
                (item['response_1'], item['is_response_1_safe']),
            )
            total += len(candidates)
            for response, is_safe in candidates:
                if is_safe:
                    continue
                record = {'prompt': prompt, 'response': response[:max_response_len]}
                f_out.write(json.dumps(record) + '\n')
                valid += 1
    print(valid, total)
    return valid, total


# Export the unsafe responses of each split in turn: train first, then test.
# Each split is loaded from the Hugging Face hub (cached locally) and written
# to its own "<split>_filtered.jsonl" file.
for split in ("train", "test"):
    ds = load_dataset(
        "PKU-Alignment/PKU-SafeRLHF", split=split, cache_dir=".cache/huggingface"
    )
    f_name = f"/data//mm-safety/data_prepare/saferlhf_harmless_1/{split}_filtered.jsonl"
    print(f' writing to {f_name} ')
    data_filter(ds, f_name)
