from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import os
from tqdm import tqdm

# Hugging Face Hub identifier of the checkpoint used for answer generation.
model_name = "Qwen/Qwen2.5-14B-Instruct"

# Both objects are created once at import time and shared by every call to
# generate_answer(). The "auto" settings leave the choice of dtype and device
# placement to the library.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

# Tokenizer matching the checkpoint above; it also supplies the chat template.
tokenizer = AutoTokenizer.from_pretrained(model_name)

def load_json_file(file_path):
    """Read the UTF-8 encoded JSON document at *file_path* and return it parsed.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON value (whatever Python object the document maps to).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def save_json_file(data, file_path):
    """Write *data* to *file_path* as pretty-printed UTF-8 JSON.

    The output uses a two-space indent and keeps non-ASCII characters as-is
    rather than escaping them.

    Args:
        data: Any JSON-serializable Python object.
        file_path: Destination path; an existing file is overwritten.

    Raises:
        TypeError: If *data* contains a value that cannot be serialized.
        ValueError: If *data* contains a circular reference.
        OSError: If the file cannot be opened or written.
    """
    # Serialize before opening the file: opening in 'w' mode truncates it, so
    # encoding first guarantees that a serialization error cannot destroy an
    # existing file or leave a half-written one behind.
    serialized = json.dumps(data, indent=2, ensure_ascii=False)
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(serialized)

def generate_answer(prompt):
    """Generate the model's reply to *prompt* under a fixed system instruction.

    Uses the module-level ``tokenizer`` and ``model``. The prompt is wrapped in
    a two-message chat (system + user), rendered with the tokenizer's chat
    template, and passed to ``model.generate`` with a limit of 1024 new tokens.

    Args:
        prompt: Text sent as the user message.

    Returns:
        The decoded completion only (prompt tokens removed, special tokens
        skipped), with leading and trailing whitespace stripped.
    """
    # Spelled out piece by piece so the trailing spaces, the embedded newline
    # and the run of spaces that follows it are all visible.
    system_prompt = (
        "You are a precision-focused assistant. Generate concise and factual answers "
        "based on the given question and reasoning context. \n"
        "            Answer should directly address the question using ONLY information "
        "from the reasoning. "
    )
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]

    # Render the chat to plain text, ending with the assistant-turn marker.
    chat_text = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer([chat_text], return_tensors="pt").to(model.device)

    output_ids = model.generate(**encoded, max_new_tokens=1024)

    # Each output row starts with the prompt tokens; drop that prefix so only
    # the newly generated tokens are decoded.
    completions = []
    for prompt_ids, full_ids in zip(encoded.input_ids, output_ids):
        completions.append(full_ids[len(prompt_ids):])

    decoded = tokenizer.batch_decode(completions, skip_special_tokens=True)
    return decoded[0].strip()

def main(input_file, output_file):
    """Generate an answer for every record in *input_file* and save the result.

    The input is loaded with ``load_json_file`` and iterated as a collection of
    dict-like records, each providing ``'question'`` and
    ``'model_reasoning_output'``. Every record gains a ``'model_answer_output'``
    entry holding the text returned by ``generate_answer``. The updated
    collection is written to *output_file* once, after all records are done.

    Args:
        input_file: Path of the JSON file to read.
        output_file: Path of the JSON file to write.

    Raises:
        KeyError: If a record lacks one of the required keys.
    """
    records = load_json_file(input_file)

    print(f"Processing {len(records)} questions...")

    progress = tqdm(records, desc="Generating answers", unit="question", ncols=100)
    for record in progress:
        question = record['question']
        reasoning = record['model_reasoning_output']
        prompt = f"Question: {question}\n Reasoning context: {reasoning}\nA: "
        record['model_answer_output'] = generate_answer(prompt)

    save_json_file(records, output_file)
    print(f"Processing completed. Results saved to {output_file}")

if __name__ == "__main__":
    import argparse

    # Default paths used when none are given on the command line. They are
    # empty here, so either edit them or pass both paths as arguments.
    input_file = ""
    output_file = ""

    parser = argparse.ArgumentParser(
        description="Generate answers for a JSON file of questions and reasoning."
    )
    parser.add_argument(
        "input_file", nargs="?", default=input_file,
        help="JSON file with 'question' and 'model_reasoning_output' per record",
    )
    parser.add_argument(
        "output_file", nargs="?", default=output_file,
        help="destination JSON file for the records with answers added",
    )
    args = parser.parse_args()

    # An empty path would only surface later as a FileNotFoundError from
    # open(""); report the real problem instead.
    if not args.input_file or not args.output_file:
        parser.error("both an input and an output JSON path are required")

    main(args.input_file, args.output_file)