import argparse
import json
import os

import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams



def sample(data, model_tag, seed, n, temperature=0.8, max_tokens=2048):
    """Draw ``n`` sampled chat completions for every prompt in ``data`` with vLLM.

    Args:
        data: Sequence of user prompt strings. Each one is wrapped as a single
            user turn and rendered with the tokenizer's chat template (with the
            generation prompt appended).
        model_tag: Hugging Face model id or local path; used to load both the
            tokenizer and the vLLM engine.
        seed: Seed forwarded to ``SamplingParams``.
        n: Number of completions to sample per prompt.
        temperature: Sampling temperature (default 0.8).
        max_tokens: Maximum number of tokens generated per completion
            (default 2048).

    Returns:
        List of dicts, one per prompt and in the same order as ``data``:
        ``{"problem": <prompt>, "sampled_resp": [<completion text>, ...]}``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_tag, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for prompt in data
    ]

    model = LLM(
        model=model_tag,
        # Shard across every visible GPU, but never request a tensor-parallel
        # size of 0 when torch sees no CUDA device.
        tensor_parallel_size=max(1, torch.cuda.device_count()),
        trust_remote_code=True,
        swap_space=8,
        gpu_memory_utilization=0.8,
    )

    # Stop on the tokenizer's EOS string plus Llama-3-style turn/header
    # markers. Falsy entries (e.g. a tokenizer whose eos_token is None) are
    # dropped because they are not valid stop strings.
    stop = [
        marker
        for marker in (
            tokenizer.eos_token,
            "<|eot_id|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
        )
        if marker
    ]
    sampling_params = SamplingParams(
        n=n,
        temperature=temperature,
        max_tokens=max_tokens,
        stop=stop,
        seed=seed,
    )

    outputs = model.generate(prompts, sampling_params)
    # Order by numeric request id so outputs[i] pairs with data[i]. This
    # assumes vLLM numbers requests in submission order starting at 0 for a
    # fresh engine -- re-check if the vLLM version changes.
    outputs = sorted(outputs, key=lambda out: int(out.request_id))

    return [
        {
            "problem": problem,
            "sampled_resp": [completion.text for completion in output.outputs],
        }
        for problem, output in zip(data, outputs)
    ]


def _output_path(dataset, model, begin, end, n, seed):
    """Return ``./<dataset>/<model-name>-<begin>-<end>-<n>-<seed>.json``.

    ``model`` may be a hub id or a filesystem path; its last ``/``-separated
    component is the model name. Trailing slashes are ignored so that
    ``/models/gemma/`` yields ``gemma`` rather than an empty name.
    """
    model_name = model.rstrip("/").split("/")[-1]
    return f'./{dataset}/{model_name}-{begin}-{end}-{n}-{seed}.json'


def main(dataset, model, begin, end, seed, n,
         cache_dir="/pfs/training-data/xiaoyao/ds"):
    """Sample completions for a slice of UltraFeedback prompts and save them as JSON.

    Loads the ``train_prefs`` split of ``HuggingFaceH4/ultrafeedback_binarized``,
    takes the ``prompt`` column for rows ``begin:end`` (Python slice semantics,
    so an ``end`` beyond the dataset size is clamped), samples ``n`` completions
    per prompt via ``sample`` and writes them to
    ``./<dataset>/<model-name>-<begin>-<end>-<n>-<seed>.json``.

    Args:
        dataset: Output sub-directory name (it does not select the dataset).
        model: Model id or path forwarded to ``sample``.
        begin: First row (inclusive) of the prompt slice.
        end: Last row (exclusive) of the prompt slice.
        seed: Sampling seed.
        n: Number of completions per prompt.
        cache_dir: Cache directory passed to ``load_dataset``.
    """
    out_path = _output_path(dataset, model, begin, end, n, seed)
    # Create the output directory before generating, so a missing directory
    # fails immediately instead of after the (potentially long) sampling run.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    ds = load_dataset("HuggingFaceH4/ultrafeedback_binarized", cache_dir=cache_dir, split="train_prefs")
    # Slice rows first so only begin:end are materialized, then take the
    # prompt column as a plain list of strings.
    data = ds[begin:end]["prompt"]
    return_outputs = sample(data, model, seed, n)

    with open(out_path, 'w', encoding='utf-8') as fout:
        json.dump(return_outputs, fout, indent=2)
        
    
if __name__ == '__main__':
    # CLI entry point: every option below becomes a keyword argument of main().
    cli = argparse.ArgumentParser()
    # --dataset only names the output sub-directory (./<dataset>/...).
    cli.add_argument('--dataset', type=str, default="ultra")
    # Model id or local path handed to both the tokenizer and vLLM.
    cli.add_argument('--model', type=str,
                     default="/pfs/training-data/hf/models/google/gemma-2-9b-it")

    # Integer options, registered in order as (flag, default, help text).
    int_options = (
        ('--begin', 0, 'begin'),
        ('--end', 70000, 'end'),
        ('--seed', 30, 'seed'),
        ('--n', 1, 'samples'),
    )
    for flag, fallback, help_text in int_options:
        cli.add_argument(flag, type=int, default=fallback, help=help_text)

    main(**vars(cli.parse_args()))