from datasets import load_dataset, Dataset
import argparse
import random
import os

def parse_args():
    """Parse command-line options for the dataset preparation script.

    Returns:
        An ``argparse.Namespace`` with ``dataset`` (str), ``seed`` (int) and
        ``max_num_examples`` (int or None).

    Exits via ``parser.error`` (SystemExit) when ``--dataset`` is missing or
    ``--max_num_examples`` is negative.
    """
    parser = argparse.ArgumentParser()
    # Required: without it the loader would receive None and fail with an
    # unhelpful TypeError deep inside os.path.exists.
    parser.add_argument('--dataset', type=str, required=True,
                        help="'lima', 'alpaca', 'dolly', or a path to a local JSON file")
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument("--max_num_examples", type=int, default=None, help="cut dataset to a certain size")
    args = parser.parse_args()
    # A negative value would be used as a slice bound downstream and silently
    # drop rows from the end instead of keeping a subset, so reject it here.
    if args.max_num_examples is not None and args.max_num_examples < 0:
        parser.error("--max_num_examples must be non-negative")
    return args

def random_sampler(dataset, num_samples, seed=0):
    """Return a reproducible random subset of ``dataset``.

    Args:
        dataset: Any object supporting ``len()`` and ``select(indices)``
            (e.g. a ``datasets.Dataset``).
        num_samples: Number of rows to keep. If it exceeds ``len(dataset)``,
            every row is returned in shuffled order.
        seed: Seed for the shuffle. The default of 0 reproduces the subsets
            this function has always produced.

    Returns:
        ``dataset.select(...)`` applied to the first ``num_samples`` indices of a
        seeded shuffle of ``range(len(dataset))``.
    """
    # A private Random instance gives the same sequence as random.seed(seed)
    # followed by random.shuffle, without clobbering the process-global RNG
    # state that other code may depend on.
    rng = random.Random(seed)
    sampling_ids = list(range(len(dataset)))
    rng.shuffle(sampling_ids)
    return dataset.select(sampling_ids[:num_samples])

def lima_processor(example):
    """Map a LIMA record onto the shared instruction/response schema.

    The first entry of ``example["conversations"]`` becomes the instruction and
    the second entry becomes the response; any further entries are ignored.
    """
    turns = example["conversations"]
    prompt = turns[0]
    answer = turns[1]
    return {
        "instruction": prompt,
        "response": answer,
    }

def alpaca_processor(example, instruction_key="instruction", response_key="output", input_key="input"):
    """Map an Alpaca-style record onto the shared instruction/response schema.

    Args:
        example: Mapping for a single row.
        instruction_key: Field holding the task instruction.
        response_key: Field holding the target answer.
        input_key: Optional field with extra context. When it is present and
            truthy, it is appended to the instruction after a blank line.

    Returns:
        ``{"instruction": ..., "response": ...}``.
    """
    instruction = example[instruction_key]
    # .get(): a local file may have no input column at all; treat that the same
    # as an empty input rather than raising KeyError.
    extra_input = example.get(input_key)
    if extra_input:
        instruction = f"{instruction}\n\n{extra_input}"
    return {"instruction": instruction, "response": example[response_key]}

def load_and_reformat(dataset_path):
    """Load a dataset by name or local path and add instruction/response fields.

    Args:
        dataset_path: One of the built-in names ``"lima"``, ``"alpaca"`` or
            ``"dolly"`` (loaded with ``datasets.load_dataset``), or a path to an
            existing local file readable by ``Dataset.from_json`` with
            Alpaca-style ``instruction``/``input``/``output`` fields.

    Returns:
        A ``datasets.Dataset`` whose rows have ``instruction`` and ``response``
        fields. For the built-in names, the original columns are kept as well,
        with same-named ones overwritten. For a local file, only the two new
        columns remain.

    Raises:
        NotImplementedError: If ``dataset_path`` is not a built-in name and is
            not an existing path (this includes ``None``).
    """
    if dataset_path == "lima":
        # NOTE(review): only the first 1000 train rows are kept; presumably
        # this excludes LIMA's multi-turn tail. TODO confirm.
        dataset = load_dataset('GAIR/lima')['train'].select(list(range(1000))).map(lima_processor)

    elif dataset_path == 'alpaca':
        # Cap the worker count at the available CPUs. alpaca_processor is a
        # cheap per-row function, so extra processes only add overhead.
        num_proc = min(64, os.cpu_count() or 1)
        dataset = load_dataset('yahma/alpaca-cleaned')['train'].map(alpaca_processor, num_proc=num_proc)

    elif dataset_path == 'dolly':
        dataset = load_dataset('databricks/databricks-dolly-15k')['train']
        # Dolly stores the optional context in "context" and the answer in
        # "response", so remap the keys alpaca_processor reads.
        dataset = dataset.map(
            alpaca_processor,
            fn_kwargs={"instruction_key": "instruction", "response_key": "response", "input_key": "context"},
        )
    elif dataset_path is not None and os.path.exists(dataset_path):
        dataset = Dataset.from_json(dataset_path)
        # Unlike the built-in branches above, this drops every source column,
        # so the output has only instruction/response.
        dataset = dataset.map(
            alpaca_processor,
            remove_columns=dataset.column_names,
            fn_kwargs={"instruction_key": "instruction", "response_key": "output", "input_key": "input"},
        )
    else:
        raise NotImplementedError(
            f"Unsupported dataset {dataset_path!r}: expected 'lima', 'alpaca', "
            "'dolly', or a path to an existing local JSON file"
        )

    return dataset

def main():
    """Load the requested dataset, optionally subsample it, and write JSON Lines.

    Output goes to ``dataset/train/<prefix>.jsonl``, or to
    ``dataset/train/<prefix>_<max_num_examples>.jsonl`` when
    ``--max_num_examples`` is given. ``<prefix>`` is the file stem for a local
    path, otherwise the dataset name.
    """
    args = parse_args()
    dataset = load_and_reformat(args.dataset)
    # Only subsample when it would actually shrink the dataset.
    # NOTE(review): --seed is parsed but not forwarded here; random_sampler is
    # called with its own fixed default seed.
    if args.max_num_examples is not None and args.max_num_examples < len(dataset):
        dataset = random_sampler(dataset, args.max_num_examples)
    save_dir = os.path.join('dataset', 'train')
    # Do not rely on to_json creating missing parent directories; without this
    # the first run from a fresh checkout fails when opening the output file.
    os.makedirs(save_dir, exist_ok=True)
    output_prefix = os.path.splitext(os.path.basename(args.dataset))[0] if os.path.exists(args.dataset) else args.dataset
    if args.max_num_examples is not None:
        save_path = os.path.join(save_dir, f"{output_prefix}_{args.max_num_examples}.jsonl")
    else:
        save_path = os.path.join(save_dir, f"{output_prefix}.jsonl")
    dataset.to_json(save_path, lines=True)

# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()