from datasets import Dataset, concatenate_datasets
import argparse
import random
import os
from .generate_training_data import load_and_reformat, random_sampler, alpaca_processor


def safety_random_sampler(dataset, num_samples, seed):
    """Return a seeded random subset of ``dataset``.

    The row indices ``0 .. len(dataset) - 1`` are shuffled with a generator
    seeded by ``seed`` and the first ``num_samples`` of them are passed to
    ``dataset.select``. If ``num_samples`` exceeds ``len(dataset)`` every
    row is selected (in shuffled order).

    Args:
        dataset: Object supporting ``len()`` and ``select(indices)``.
        num_samples: Number of rows to keep.
        seed: Seed for the shuffle; the same seed yields the same selection.

    Returns:
        Whatever ``dataset.select`` returns for the chosen indices.
    """
    # A private generator gives the same shuffle as ``random.seed(seed)``
    # followed by ``random.shuffle`` but does not reseed the module-level
    # RNG, so unrelated callers of ``random`` are not affected.
    rng = random.Random(seed)
    sampling_ids = list(range(len(dataset)))
    rng.shuffle(sampling_ids)
    return dataset.select(sampling_ids[:num_samples])

def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace with the attributes ``target_dataset``,
        ``safety_examples_path``, ``num_safety_examples``,
        ``num_utility_examples`` and ``seed``.
    """
    # (flag, type, default) for every supported option, in declaration order.
    option_specs = (
        ("--target_dataset", str, None),
        ("--safety_examples_path", str, None),
        ("--num_safety_examples", int, 256),
        ("--num_utility_examples", int, 1024),
        ("--seed", int, 0),
    )
    cli = argparse.ArgumentParser()
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli.parse_args()

def mix(target_dataset, safety_examples_path, num_alpaca_examples, num_safety_examples, seed):
    """Build a training set from utility examples plus safety examples.

    Steps, in order:
      1. Load ``target_dataset`` through ``load_and_reformat`` and pass it to
         ``random_sampler`` together with ``num_alpaca_examples``.
      2. Read the safety examples from the JSON file at
         ``safety_examples_path`` and map them through ``alpaca_processor``,
         replacing all of their original columns.
      3. Pick ``num_safety_examples`` of them with ``safety_random_sampler``.
      4. Drop every utility column the safety set does not have, then
         concatenate the two datasets.

    Note that ``seed`` is only handed to the safety sampling step;
    ``random_sampler`` is called without it.

    Returns:
        The result of ``concatenate_datasets`` on the utility subset followed
        by the safety subset.
    """
    utility = random_sampler(load_and_reformat(target_dataset), num_alpaca_examples)

    raw_safety = Dataset.from_json(safety_examples_path)
    # Field names handed to alpaca_processor for each safety record.
    field_mapping = {"instruction_key": "instruction", "response_key": "output", "input_key": "input"}
    formatted_safety = raw_safety.map(
        alpaca_processor,
        remove_columns=raw_safety.column_names,
        fn_kwargs=field_mapping,
    )
    safety_subset = safety_random_sampler(formatted_safety, num_safety_examples, seed)

    # Both datasets are reduced to the safety set's columns before merging.
    shared_columns = safety_subset.column_names
    extra_columns = [name for name in utility.column_names if name not in shared_columns]
    utility = utility.remove_columns(extra_columns)
    return concatenate_datasets([utility, safety_subset])

def main():
    """Mix the target dataset with safety examples and write it as JSON lines.

    Reads the options declared in ``parse_args``, builds the mixed dataset
    with ``mix`` and saves it under ``dataset/train``. The file name is
    ``<prefix>_<num_utility_examples>_safety_<num_safety_examples>.jsonl``,
    where ``<prefix>`` is the target dataset's file name without extension
    when ``--target_dataset`` is an existing path, and the raw argument
    value otherwise.

    Raises:
        SystemExit: If ``--target_dataset`` or ``--safety_examples_path`` is
            not supplied.
    """
    args = parse_args()
    # Both options default to None; fail early with a readable message
    # instead of an obscure error further down.
    if args.target_dataset is None or args.safety_examples_path is None:
        raise SystemExit("--target_dataset and --safety_examples_path are required")

    dataset = mix(
        args.target_dataset,
        args.safety_examples_path,
        args.num_utility_examples,
        args.num_safety_examples,
        args.seed,
    )

    save_dir = os.path.join('dataset', 'train')
    if os.path.exists(args.target_dataset):
        output_prefix = os.path.splitext(os.path.basename(args.target_dataset))[0]
    else:
        output_prefix = args.target_dataset
    save_path = os.path.join(
        save_dir,
        f"{output_prefix}_{args.num_utility_examples}_safety_{args.num_safety_examples}.jsonl",
    )
    # The prefix may itself contain path separators, so create the directory
    # of the final path rather than just save_dir.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    dataset.to_json(save_path, lines=True)

# Script entry point. NOTE: because this file uses a relative import
# (``from .generate_training_data import ...``), it has to be executed as a
# module inside its package (``python -m <package>.<module>``) rather than by
# bare file path.
if __name__ == '__main__':
    main()