import jsonlines
from vllm import LLM
from vllm import SamplingParams
import argparse
import torch
from transformers import GenerationConfig


# Prompt wrapper used as the single user message for each task. `{problem}`
# is filled via str.format, so the literal braces of \boxed{} are escaped as
# `{{}}`, and `\\` in this non-raw string yields one backslash.
# NOTE(review): "If you need ask information" reads like a typo for "If you
# need to ask for information"; left unchanged because editing the prompt
# changes model inputs (and comparability of results) — confirm before fixing.
template = """
# Problem to Solve

{problem}

# Instruction

If you need ask information, please raise clarification question and start your response STRICTLY with: "Clarification Question" followed by your questions.
Otherwise, please reason step by step, and put your final answer within \\boxed{{}}.
"""


def _build_prompts(data, use_raw):
    """Render `template` once per input record.

    Args:
        data: list of dict records read from the input JSONL file.
        use_raw: if true, read each record's 'raw_task' field; otherwise its
            'unclear_task' field.

    Returns:
        List of prompt strings, in the same order as `data`.
    """
    key = 'raw_task' if use_raw else 'unclear_task'
    # NOTE(review): the task text is lowercased before formatting, which also
    # lowercases variable names and LaTeX macros (e.g. \Delta -> \delta).
    # Kept to preserve existing outputs — confirm this is intended.
    return [template.format(problem=item[key].lower()) for item in data]


def _get_stop_token_ids(model_path):
    """Return the EOS token id(s) from the model's generation config as a list.

    `GenerationConfig.eos_token_id` may be a single id, a list of ids, or None
    when the config defines no EOS token. None maps to an empty list so vLLM
    is never handed a `[None]` stop-token list.
    """
    eos = GenerationConfig.from_pretrained(model_path).eos_token_id
    if eos is None:
        return []
    if isinstance(eos, list):
        return list(eos)
    return [eos]


def main():
    """Generate one response per task record with vLLM and write them as JSONL.

    Command-line arguments:
        --input_file: JSONL file; each record must contain 'unclear_task'
            (default) or 'raw_task' (with --use_raw).
        --output_file: JSONL file to write; one object per input record with
            keys 'prompt', 'response', and 'metadata' (the original record).
        --model_path: model path/name, passed to both
            transformers.GenerationConfig and vllm.LLM.
        --use_raw: use 'raw_task' instead of 'unclear_task'.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file', type=str, required=True)
    parser.add_argument('--output_file', type=str, required=True)
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--use_raw', action='store_true')
    args = parser.parse_args()

    # Read through a context manager so the input file handle is closed.
    with jsonlines.open(args.input_file) as reader:
        data = list(reader)
    prompts = _build_prompts(data, args.use_raw)

    if not prompts:
        # Nothing to generate: avoid indexing an empty list and skip loading
        # the model, but still create the (empty) output file.
        print(f'No records found in {args.input_file}; writing empty output.')
        with jsonlines.open(args.output_file, 'w'):
            pass
        return

    print(prompts[0])  # show one rendered prompt as a quick sanity check
    messages = [[{'role': 'user', 'content': prompt}] for prompt in prompts]

    stop_token_ids = _get_stop_token_ids(args.model_path)
    # Shard the model across every visible CUDA device.
    model = LLM(
        model=args.model_path,
        tensor_parallel_size=torch.cuda.device_count(),
        gpu_memory_utilization=0.85,
    )
    # temperature=0 selects greedy decoding in vLLM.
    sampling_params = SamplingParams(
        temperature=0.,
        max_tokens=20480,
        stop_token_ids=stop_token_ids,
        presence_penalty=1,
    )
    responses = model.chat(messages, sampling_params, use_tqdm=True)

    # responses come back in the same order as messages, so zip pairs them up.
    output_data = [
        {
            'prompt': prompt,
            'response': response.outputs[0].text,
            'metadata': item,
        }
        for prompt, response, item in zip(prompts, responses, data)
    ]
    with jsonlines.open(args.output_file, 'w') as f:
        f.write_all(output_data)


# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
