from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt
from datasets import load_dataset
import os

import argparse
import json

def _tensor_parallel_size():
    """Return how many GPUs the model should be sharded across.

    Counts the non-empty entries of CUDA_VISIBLE_DEVICES when it is set.
    Otherwise it asks torch how many CUDA devices are visible.

    Raises:
        RuntimeError: if no CUDA device is visible.
    """
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        # Unset means "all GPUs"; indexing os.environ directly would raise KeyError here.
        import torch  # local import: only needed on this fallback path
        count = torch.cuda.device_count()
    else:
        # Skip empty entries so "0,1," or "" are not miscounted.
        count = len([dev for dev in visible.split(',') if dev.strip()])
    if count < 1:
        raise RuntimeError("No CUDA devices are visible; cannot run vLLM generation.")
    return count


def _flatten_outputs(prompts, outputs):
    """Pair each prompt with every sampled completion for it.

    Returns a list of {'prompt': ..., 'generated_text': ...} dicts, one per
    sample, so a prompt decoded with n samples appears n times.
    """
    records = []
    # zip relies on outputs being aligned with prompts (one RequestOutput per prompt, in order).
    for prompt, output in zip(prompts, outputs):
        for output_instance in output.outputs:
            records.append({
                'prompt': prompt,
                'generated_text': output_instance.text,
            })
    return records


def main():
    """Sample 5 completions per instruction of the Magpie-Air DPO train split and save them as JSON."""
    parser = argparse.ArgumentParser(description='Decode with vllm')
    # required=True: otherwise a missing flag only surfaces as open(None) *after*
    # the (long) generation has finished, discarding all of the work.
    parser.add_argument('--save_path', type=str, required=True, help='Path to save sampled data')
    args = parser.parse_args()
    print(f"Saving to {args.save_path}...")

    # Create the destination directory up front so the final write cannot fail on a missing folder.
    os.makedirs(os.path.dirname(os.path.abspath(args.save_path)), exist_ok=True)

    world_size = _tensor_parallel_size()
    generator = 'google/gemma-2-27b-it'

    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER" # this is recommended for gemma-2 models; otherwise it is not needed
    llm = LLM(model=generator, tokenizer=generator, gpu_memory_utilization=0.95, tensor_parallel_size=world_size)
    tokenizer = llm.get_tokenizer()
    data_dir = 'Magpie-Align/Magpie-Air-DPO-100K-v0.1'

    train_dataset = load_dataset(data_dir, split='train')
    prompts = train_dataset['instruction']
    # Pre-tokenize with the model's chat template so vLLM receives token ids directly.
    # NOTE(review): confirm add_special_tokens is honoured by apply_chat_template in the
    # installed transformers version (goal: no duplicated BOS token).
    tokenized_prompts = [
        TokensPrompt(prompt_token_ids=tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=True, add_generation_prompt=True, add_special_tokens=False))
        for prompt in prompts
    ]
    sampling_params = SamplingParams(n=5, temperature=0.9, top_p=1, max_tokens=4096, seed=42)
    outputs = llm.generate(tokenized_prompts, sampling_params)

    # Save the outputs as a JSON file.
    output_data = _flatten_outputs(prompts, outputs)
    with open(args.save_path, 'w', encoding='utf8') as f:
        json.dump(output_data, f, indent=4, ensure_ascii=False)


if __name__ == '__main__':
    main()