"""
- Preprocess the data and split the training set into 75% for training the RM and 25% for validating the RM.
- All of the training data is used to train SFT and RL.
- Both the chosen and the rejected responses are used to train SFT.
"""
import argparse
import os

import pandas as pd
from datasets import load_dataset

from tqdm.auto import tqdm

from verl.utils.fs import copy, makedirs


def generate_sft_dataset(target_hdfs_path_dir, local_dir='~/data/full_hh_rlh/sft'):
    """Build the SFT split of Dahoas/full-hh-rlhf and write it as ``train.parquet``.

    Every training example contributes two rows with the same prompt: first one
    whose response is the ``chosen`` text, then one whose response is the
    ``rejected`` text.

    Args:
        target_hdfs_path_dir: Optional destination directory to which
            ``train.parquet`` is copied after being written locally. Nothing is
            copied when this is ``None``.
        local_dir: Local output directory. ``~`` is expanded and the directory
            is created if it does not exist.
    """
    # NOTE(review): this default spells "full_hh_rlh" while the CLI default and
    # generate_rl_dataset use "full_hh_rlhf"; left unchanged for compatibility.
    dataset = load_dataset('Dahoas/full-hh-rlhf')
    output = {'prompt': [], 'response': []}
    for data in tqdm(dataset['train']):
        # Both the chosen and the rejected response become SFT targets,
        # chosen first, each paired with the same prompt.
        for response_key in ('chosen', 'rejected'):
            output['prompt'].append(data['prompt'])
            output['response'].append(data[response_key])

    df = pd.DataFrame(output)

    local_dir = os.path.expanduser(local_dir)
    os.makedirs(local_dir, exist_ok=True)

    local_path = os.path.join(local_dir, 'train.parquet')

    df.to_parquet(path=local_path)

    if target_hdfs_path_dir is not None:
        hdfs_path = target_hdfs_path_dir + '/' + 'train.parquet'
        # Create only the parent directory. Creating ``hdfs_path`` itself would
        # make a directory named "train.parquet" and the copy would land inside it.
        # Assumes verl.utils.fs.makedirs mirrors os.makedirs' ``exist_ok`` — confirm.
        makedirs(target_hdfs_path_dir, exist_ok=True)

        copy(local_path, hdfs_path)


def generate_rm_dataset(target_hdfs_path_dir, local_dir='~/data/full_hh_rlh/rm'):
    """Build reward-model train/test splits of Dahoas/full-hh-rlhf.

    The first 75% of the original ``train`` split is written to
    ``train.parquet`` and the last 25% to ``test.parquet``. Each row keeps the
    ``prompt``, ``chosen`` and ``rejected`` fields of the source example.

    Args:
        target_hdfs_path_dir: Optional destination directory to which both
            parquet files are copied after being written locally. Nothing is
            copied when this is ``None``.
        local_dir: Local output directory. ``~`` is expanded and the directory
            is created if it does not exist.
    """
    # NOTE(review): this default spells "full_hh_rlh" while the CLI default and
    # generate_rl_dataset use "full_hh_rlhf"; left unchanged for compatibility.
    train_dataset = load_dataset('Dahoas/full-hh-rlhf', split='train[:75%]')
    test_dataset = load_dataset('Dahoas/full-hh-rlhf', split='train[-25%:]')

    local_dir = os.path.expanduser(local_dir)
    os.makedirs(local_dir, exist_ok=True)

    if target_hdfs_path_dir is not None:
        # Create the destination directory once for both files. The per-file
        # paths must not be created as directories, or the copies would land
        # inside them.
        # Assumes verl.utils.fs.makedirs mirrors os.makedirs' ``exist_ok`` — confirm.
        makedirs(target_hdfs_path_dir, exist_ok=True)

    for dataset, name in zip([train_dataset, test_dataset], ['train', 'test']):
        output = {'prompt': [], 'chosen': [], 'rejected': []}
        for data in tqdm(dataset):
            # Keep each (prompt, chosen, rejected) preference triple intact.
            output['prompt'].append(data['prompt'])
            output['chosen'].append(data['chosen'])
            output['rejected'].append(data['rejected'])

        df = pd.DataFrame(output)

        local_path = os.path.join(local_dir, name + '.parquet')

        df.to_parquet(path=local_path)

        if target_hdfs_path_dir is not None:
            hdfs_path = target_hdfs_path_dir + '/' + name + '.parquet'
            copy(local_path, hdfs_path)


def generate_rl_dataset(target_hdfs_path_dir, local_dir='~/data/full_hh_rlhf/rl'):
    """Build the RL split of Dahoas/full-hh-rlhf and write it as ``train.parquet``.

    Each training example is rewritten into a record with ``data_source``, a
    single-turn user ``prompt`` message, ``ability``, a model-style
    ``reward_model`` entry and ``extra_info`` holding the split name and the
    row index.

    Args:
        target_hdfs_path_dir: Optional destination directory to which
            ``train.parquet`` is copied after being written locally. Nothing is
            copied when this is ``None``.
        local_dir: Local output directory. ``~`` is expanded and the directory
            is created if it does not exist.
    """
    dataset = load_dataset('Dahoas/full-hh-rlhf')
    train_dataset = dataset['train']

    data_source = 'Dahoas/full-hh-rlhf'

    def make_map_fn(split):
        """Return a ``Dataset.map`` callback that tags each record with ``split``."""

        def process_fn(example, idx):
            # ``prompt`` and ``response`` are popped so they are replaced by the
            # fields built below.
            prompt = example.pop('prompt')
            response = example.pop('response')

            data = {
                "data_source": data_source,
                "prompt": [{
                    "role": "user",
                    "content": prompt
                }],
                "ability": "alignment",
                "reward_model": {
                    "style": "model",
                    "ground_truth": response  # should not be used
                },
                "extra_info": {
                    'split': split,
                    # Row position supplied by ``with_indices=True``.
                    'index': idx
                }
            }
            return data

        return process_fn

    train_dataset = train_dataset.map(function=make_map_fn('train'), with_indices=True)

    local_dir = os.path.expanduser(local_dir)
    # Match the SFT/RM generators: make sure the output directory exists
    # before writing, instead of relying on it already being there.
    os.makedirs(local_dir, exist_ok=True)
    local_path = os.path.join(local_dir, 'train.parquet')
    train_dataset.to_parquet(local_path)

    if target_hdfs_path_dir is not None:
        hdfs_path = target_hdfs_path_dir + '/' + 'train.parquet'
        # Create only the parent directory. Creating ``hdfs_path`` itself would
        # make a directory named "train.parquet" and the copy would land inside it.
        # Assumes verl.utils.fs.makedirs mirrors os.makedirs' ``exist_ok`` — confirm.
        makedirs(target_hdfs_path_dir, exist_ok=True)

        copy(local_path, hdfs_path)


if __name__ == '__main__':
    # Command-line entry point: pick one split and write it under
    # <local_dir>/<split>, optionally mirroring it to --hdfs_dir.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--split', type=str, choices=['sft', 'rm', 'rl'], required=True)
    arg_parser.add_argument('--local_dir', type=str, default='~/data/full_hh_rlhf')
    arg_parser.add_argument('--hdfs_dir', type=str, required=False, default=None)
    cli_args = arg_parser.parse_args()

    split_generators = {
        'sft': generate_sft_dataset,
        'rm': generate_rm_dataset,
        'rl': generate_rl_dataset,
    }
    chosen_generator = split_generators.get(cli_args.split)
    if chosen_generator is None:
        # Unreachable while argparse enforces ``choices``; kept as a safeguard.
        raise NotImplementedError
    split_dir = os.path.join(cli_args.local_dir, cli_args.split)
    chosen_generator(cli_args.hdfs_dir, split_dir)
