from vllm import LLM, SamplingParams
from datasets import load_dataset, load_from_disk
import os
import torch
import random
import numpy as np
import argparse
import json
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
from tqdm import tqdm

def set_seed(seed=5775709):
    """Seed every random number generator this script relies on.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator, and PyTorch's
    CUDA generators (via ``torch.cuda.manual_seed_all``).

    Args:
        seed: Integer seed handed unchanged to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

if __name__ == '__main__':
    # Script entry point: generate one completion per dataset prompt with vLLM
    # and write the resulting (instruction, output) pairs to a JSON file.
    parser = argparse.ArgumentParser(description='Decode with vllm')
    parser.add_argument('--data_dir', type=str, default="UCLA-AGI/data-mistral-7b-instruct-sppo-iter3",
                        help='Directory containing the data')
    parser.add_argument('--model', type=str, default="mistralai/Mistral-7B-Instruct-v0.2",
                        help='LLM model name')
    parser.add_argument('--temperature', type=float, default=0.7,
                        help='Temperature for sampling')
    parser.add_argument('--top_p', type=float, default=0.9,
                        help='Top-p probability for sampling')
    parser.add_argument('--max_tokens', type=int, default=2048,
                        help='Maximum number of tokens to generate')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--output_dir', type=str, default="outputs",
                        help='output_dir')
    # Previously hard-coded to 8; the default keeps the old behaviour and
    # matches the eight devices listed in CUDA_VISIBLE_DEVICES above.
    parser.add_argument('--tensor_parallel_size', type=int, default=8,
                        help='Number of GPUs to shard the model across')

    args = parser.parse_args()
    print(args)

    set_seed(args.seed)

    llm = LLM(model=args.model, tokenizer=args.model,
              tensor_parallel_size=args.tensor_parallel_size)
    tokenizer = llm.get_tokenizer()

    train_dataset = load_dataset(args.data_dir, split='train')
    prompts = list(train_dataset['prompt'])

    # Wrap every raw prompt as a single-turn user message and render it with
    # the model's chat template (as text, with the generation prompt appended).
    conversations = [
        tokenizer.apply_chat_template(
            [{'role': 'user', 'content': prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for prompt in prompts
    ]

    # Sampling settings shared by all models; only the textual stop string
    # differs, so build the keyword arguments once instead of duplicating
    # the whole constructor call.
    sampling_kwargs = {
        'temperature': args.temperature,
        'top_p': args.top_p,
        'max_tokens': args.max_tokens,
        'seed': args.seed,
        'stop_token_ids': [tokenizer.eos_token_id],
    }
    if 'Mistral' in args.model:
        # Mistral models additionally stop on the literal end-of-sequence text.
        sampling_kwargs['stop'] = ["</s>"]
    sampling_params = SamplingParams(**sampling_kwargs)

    outputs = llm.generate(conversations, sampling_params)

    # Pair each original (un-templated) prompt with the first generated
    # candidate. This relies on `outputs` being positionally aligned with
    # `conversations`, exactly as the earlier index-based lookup did.
    output_data = [
        {
            'instruction': prompt,
            'output': output.outputs[0].text,
        }
        for prompt, output in tqdm(zip(prompts, outputs), total=len(outputs))
    ]

    # Save the outputs as a JSON file.
    output_file = f'{args.model.split("/")[-1]}_outputs_uf_{args.seed}_iteration3.json'
    output_path = os.path.join(args.output_dir, output_file)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.output_dir, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=4)

    print(f"Outputs saved to {output_path}")