from vllm import LLM, SamplingParams
import json
import random
from pathlib import Path

def ours_(llm, data_path, if_train=False, *, num_samples=128, batch_size=16,
          seed=42, save_dir='/data', sampling_params=None):
    """Overwrite the preferred response of a random dataset subset with model output.

    Loads a JSON list from ``data_path`` and draws up to ``num_samples`` items
    with a fixed seed. It generates one completion per item with ``llm`` and
    writes the modified subset to ``<save_dir>/<stem of data_path>.json``.

    Args:
        llm: A vLLM ``LLM``, or any object with a compatible
            ``generate(prompts, params)`` whose results expose ``.outputs[0].text``.
        data_path: Path to a JSON file containing a list of dict items.
        if_train: If False, prompts come from ``item['question']``. The
            generation then overwrites ``response_A`` or ``response_B``
            according to ``item['good_response']`` ('A' or 'B'; any other
            value leaves the item unchanged). If True, prompts come from
            ``item['instruction']`` and the generation overwrites
            ``item['chosen']``.
        num_samples: Maximum number of items to sample. All items are used
            if the file holds fewer.
        batch_size: Number of prompts passed to each ``llm.generate`` call.
        seed: Seed for the subset selection, so the subset is reproducible.
        save_dir: Output directory; it is created if missing.
        sampling_params: vLLM ``SamplingParams``. Defaults to
            ``temperature=1, top_p=0.95, max_tokens=1024``.

    Returns:
        None. The result is written to disk.
    """
    if sampling_params is None:
        sampling_params = SamplingParams(temperature=1, top_p=0.95, max_tokens=1024)

    with open(data_path, 'r', encoding='utf-8') as f:
        datas = json.load(f)

    # A private seeded RNG picks the same items as `random.seed(seed)` followed
    # by `random.sample(...)`, without reseeding the global `random` module.
    rng = random.Random(seed)
    # random.sample raises ValueError when k > len(population), so cap k.
    selected_samples = rng.sample(datas, min(num_samples, len(datas)))

    prompt_key = 'instruction' if if_train else 'question'
    question_list = [item[prompt_key] for item in selected_samples]

    adv_response = []
    for start in range(0, len(question_list), batch_size):
        batch = question_list[start:start + batch_size]
        outputs = llm.generate(batch, sampling_params)
        # Assumes generate() returns outputs in prompt order (vLLM's LLM.generate
        # does); the zip below relies on this alignment.
        adv_response.extend(output.outputs[0].text for output in outputs)

    for data, new in zip(selected_samples, adv_response):
        if if_train:
            data['chosen'] = new
        elif data['good_response'] == 'A':
            data['response_A'] = new
        elif data['good_response'] == 'B':
            data['response_B'] = new

    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # NOTE(review): the output name is only the input's stem. An input file that
    # already lives in save_dir is therefore overwritten, and two inputs with the
    # same stem clobber each other.
    save_path = out_dir / f'{Path(data_path).stem}.json'
    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(selected_samples, f, indent=2, ensure_ascii=False)


if __name__ == '__main__':
    # Load the model once and reuse it for every dataset below.
    llm = LLM(model='/path/to/model')

    # Datasets processed with if_train=False: ours_ reads 'question' and
    # 'good_response' from each item.
    eval_paths = [
        '/path/to/data'
    ]
    for eval_path in eval_paths:
        ours_(llm, eval_path)

    # Dataset processed with if_train=True: ours_ reads 'instruction' and
    # overwrites 'chosen'.
    ours_(llm, '/path/to/data', True)

