""" Preprocess dataset for knights and knaves logic task """

import os
from datasets import Dataset, load_dataset
from tqdm import tqdm
from verl.utils.hdfs_io import copy, makedirs
import argparse
import json
import random
import pandas as pd

if __name__ == '__main__':
    # Command-line options.
    # NOTE(review): --train_size and --template_type are parsed, but nothing in
    # the visible part of this script applies them; the train split is simply
    # every row left over once the test split has been taken. Confirm whether
    # --train_size is meant to cap the train split.
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='XXX',
                        help='local directory to save processed data')
    parser.add_argument('--hdfs_dir', default=None)
    parser.add_argument('--data_path', default='XXX',
                        help='raw dataset path, .csv or .parquet')
    parser.add_argument('--train_size', type=int, default=213)
    parser.add_argument('--test_size', type=int, default=200)
    parser.add_argument('--template_type', type=str, default='qwen-instruct')

    args = parser.parse_args()
    data_source = 'jailbreak'
    TRAIN_SIZE = args.train_size
    TEST_SIZE = args.test_size

    # Fixed seed so the shuffle below produces the same split on every run.
    random.seed(42)

    # Load the raw rows; the file extension selects the reader.
    if args.data_path.endswith('.parquet'):
        raw_dataset = load_dataset('parquet', data_files=args.data_path)['train']
    elif args.data_path.endswith('.csv'):
        raw_dataset = Dataset.from_pandas(pd.read_csv(args.data_path))
    else:
        # Without this branch raw_dataset was never bound and the script died
        # on the next line with an unrelated-looking NameError.
        raise ValueError(
            f"unsupported --data_path {args.data_path!r}: "
            "expected a .csv or .parquet file"
        )
    print(f"raw dataset size: {len(raw_dataset)}")

    # Shuffle row indices, then carve off the first TEST_SIZE for the test
    # split; everything that remains becomes the train split.
    dataset_indices = list(range(len(raw_dataset)))
    random.shuffle(dataset_indices)

    test_indices = dataset_indices[:TEST_SIZE]
    train_indices = dataset_indices[TEST_SIZE:]

    train_dataset = raw_dataset.select(train_indices)
    test_dataset = raw_dataset.select(test_indices)

    print(f"train set size: {len(train_dataset)}")
    print(f"test set size: {len(test_dataset)}")

    def make_map_fn(split):
        """Build the per-row mapper handed to ``Dataset.map`` for one split.

        Args:
            split: Label stored unchanged under ``extra_info['split']``.

        Returns:
            A function ``(example, idx) -> dict`` that reads the row's
            ``'Question'`` field and returns the record fields to attach.
        """
        def process_fn(example, idx):
            prompt_text = example['Question']

            # Single user turn; the question text is stored under both
            # "content" and "question".
            user_turn = {
                "role": "user",
                "content": prompt_text,
                "question": prompt_text,
            }
            reward_spec = {
                "style": "rule",
                "ground_truth": "jailbreak",
            }
            provenance = {
                'split': split,
                'index': idx,
            }

            return {
                "data_source": data_source,
                "prompt": [user_turn],
                "ability": "logic",
                "reward_model": reward_spec,
                "extra_info": provenance,
            }

        return process_fn

    # Tag every row with the split it actually belongs to. The train split
    # used to be labelled 'test' as well, so extra_info['split'] could not
    # tell the two outputs apart.
    train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)
    test_dataset = test_dataset.map(function=make_map_fn('test'), with_indices=True)

    # Expand '~' once so the directory that gets created, the parquet files
    # that get written, and the optional HDFS copy all use the same path.
    local_dir = os.path.expanduser(args.local_dir)
    hdfs_dir = args.hdfs_dir

    # Create local directory if not exists
    os.makedirs(local_dir, exist_ok=True)

    train_dataset.to_parquet(os.path.join(local_dir, 'S2_train.parquet'))
    test_dataset.to_parquet(os.path.join(local_dir, 'S2_test.parquet'))

    # Optionally mirror the local output directory to HDFS.
    if hdfs_dir is not None:
        makedirs(hdfs_dir)
        copy(src=local_dir, dst=hdfs_dir)