import json 
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from reward_model_filter import RewardModel
import os 
import argparse
from tqdm import tqdm

class SampleMaker:
    """Draw several sampled chat completions per query with a vLLM engine.

    Each query is sent as a single-turn user message and ``sample_nums``
    completions are sampled per query (temperature 1, top-p 0.95, at most
    1024 new tokens each).
    """

    def __init__(self, model_path, sample_nums, tensor_parallel_size=4,
                 cuda_visible_devices="0,1,2,3"):
        """Load the model/tokenizer and build the sampling configuration.

        Args:
            model_path: Model name or path given to ``LLM`` and
                ``AutoTokenizer.from_pretrained``.
            sample_nums: Completions to sample per query. Coerced to ``int``
                so values taken straight from the command line also work.
            tensor_parallel_size: Passed to ``LLM(tensor_parallel_size=...)``.
            cuda_visible_devices: Value written to ``CUDA_VISIBLE_DEVICES``
                before the engine is created; ``None`` leaves the environment
                untouched.
        """
        # Coerce first so a bad value fails before the model is loaded.
        self.sample_nums = int(sample_nums)
        if cuda_visible_devices is not None:
            # Set before constructing the engine; assumes CUDA has not been
            # initialised earlier in the process — TODO confirm.
            os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
        self.llm = LLM(model=model_path, tensor_parallel_size=tensor_parallel_size)
        # Not used by get_sample; kept on the instance for other callers.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.sampling_params = SamplingParams(
            temperature=1, top_p=0.95, max_tokens=1024, n=self.sample_nums
        )

    def get_sample(self, queries):
        """Sample ``self.sample_nums`` responses for every query.

        Args:
            queries: Iterable of user prompt strings (consumed once).

        Returns:
            One ``{"question": query, "response": [text, ...]}`` dict per
            query, in input order. An empty input returns ``[]`` without
            calling the engine.
        """
        # Materialise once: the queries are needed both to build the prompts
        # and to pair them with the outputs below.
        queries = list(queries)
        if not queries:
            return []
        messages = [[{"role": "user", "content": query}] for query in queries]
        outputs = self.llm.chat(messages, self.sampling_params)
        return [
            {
                "question": query,
                "response": [output.outputs[j].text for j in range(self.sample_nums)],
            }
            for query, output in zip(queries, outputs)
        ]


def _write_json(path, data):
    """Overwrite ``path`` with ``data`` as pretty-printed UTF-8 JSON."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


def make_samples(args):
    """Sample responses for every instruction in ``args.sample_path``.

    Reads a JSON list of records that each carry an ``instruction`` field.
    It samples ``args.sample_nums`` responses per instruction in batches of
    ``args.batch_size`` and writes the accumulated results to
    ``args.save_output``.

    The output file is rewritten after every batch, so partial results are on
    disk if the run is interrupted. An empty input still produces an output
    file containing ``[]``.

    Raises:
        ValueError: If ``args.batch_size`` is not a positive integer.
    """
    print('chosen adv make sample🚀')
    # Coerce here as well: values may arrive as strings from the CLI.
    bs = int(args.batch_size)
    if bs <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {args.batch_size!r}")

    # Read the input before building the (expensive) model so a bad path
    # fails fast.
    with open(args.sample_path, 'r', encoding='utf-8') as f:
        datas = json.load(f)
    model = SampleMaker(args.model_path, int(args.sample_nums))

    queries_batch = [
        [item['instruction'] for item in datas[i:i + bs]]
        for i in range(0, len(datas), bs)
    ]

    final_json = []
    for batch in tqdm(queries_batch, desc='sample making process...', total=len(queries_batch)):
        final_json.extend(model.get_sample(queries=batch))
        # Checkpoint: rewrite the full accumulated result after each batch.
        _write_json(args.save_output, final_json)

    if not queries_batch:
        # The loop never ran, so nothing was written; still emit an output.
        _write_json(args.save_output, final_json)
    print("sample finished✅")


def parse_args(argv=None):
    """Parse the sampling script's command-line options.

    Args:
        argv: Argument list to parse; ``None`` (the default) uses ``sys.argv``.

    Returns:
        argparse.Namespace with ``sample_nums`` (int), ``model_path``,
        ``sample_path`` (``None`` when omitted), ``save_output`` and
        ``batch_size`` (int).
    """
    parser = argparse.ArgumentParser(description='Command - line argument parser for testing')
    # type=int: without it, values supplied on the command line arrive as
    # strings and break range()/SamplingParams(n=...) downstream.
    parser.add_argument('--sample_nums', type=int, default=2,
                        help='number of responses sampled per instruction')
    parser.add_argument('--model_path', default='llama3.1-8B',
                        help='model name or path to load')
    parser.add_argument('--sample_path',
                        help='JSON file holding a list of records with an "instruction" field')
    parser.add_argument('--save_output', default='./chosen_data',
                        help='path the sampled JSON is written to')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='number of instructions per generation call')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    # Use the placeholder path only when --sample_path was not supplied, so
    # the command-line flag is honoured.
    if not args.sample_path:
        args.sample_path = 'path/to/train_data'
    make_samples(args)

        

