import argparse
import jsonlines
from transformers import AutoTokenizer
import numpy as np
from tqdm import tqdm

# User prompt used by process_data_bug to rebuild each clarification-set
# prompt. `{problem}` is filled via str.format, so the literal braces of
# \boxed{} are written as `{{}}`; `\\boxed` in this non-raw string yields a
# single backslash in the rendered prompt.
# NOTE(review): the wording "If you need ask information" is ungrammatical but
# is runtime prompt text; it is kept verbatim so rebuilt prompts stay
# byte-identical to previously generated data — change only deliberately.
template = """
# Problem to Solve

{problem}

# Instruction

If you need ask information, please raise clarification question and start your response STRICTLY with: "Clarification Question" followed by your questions.
Otherwise, please reason step by step, and put your final answer within \\boxed{{}}.
"""

# Assistant response layout used when --include_thought is set: the reasoning
# goes inside <think> ... </think>, followed by a blank line and the answer.
thought_template = "<think>\n{thought}\n</think>\n\n{answer}"
# Assistant response layout without reasoning: the answer text alone.
normal_template = "{answer}"

def process_data_bug(data):
    """Rewrite each record's ``'prompt'`` in place using the module ``template``.

    The problem statement is taken from the text between the first
    ``'# Problem to Solve'`` marker and the next ``'# Instruction'`` marker,
    stripped of surrounding whitespace, lower-cased, and substituted into
    ``template``.

    NOTE(review): the lower-casing looks intentional given the function name
    (reproducing an earlier processing bug) — confirm before "fixing" it.

    Args:
        data: list of dict records, each with a ``'prompt'`` string.

    Returns:
        The same list object, whose records have been mutated.

    Raises:
        KeyError: if a record has no ``'prompt'`` key.
        IndexError: if a prompt lacks the ``'# Problem to Solve'`` marker.
    """
    for record in data:
        after_header = record['prompt'].split('# Problem to Solve')[1]
        problem_text = after_header.split('# Instruction')[0].strip()
        record['prompt'] = template.format(problem=problem_text.lower())
    return data

def _load_jsonl(path):
    """Return every record of the JSON-lines file at ``path`` as a list.

    A falsy ``path`` (``None`` / ``''``) yields an empty list, matching the
    optional ``--input_file_*`` flags.
    """
    if not path:
        return []
    # Close the file once fully read; the previous bare
    # ``list(jsonlines.open(path))`` form never closed the reader.
    with jsonlines.open(path) as reader:
        return list(reader)


def _build_messages(item, include_thought):
    """Build the single-turn (user, assistant) message lists for one record.

    Args:
        item: record with ``'prompt'`` and ``'answer'`` keys, plus
            ``'thought'`` when ``include_thought`` is true.
        include_thought: if true, format the response with
            ``thought_template`` (reasoning in ``<think>`` tags before the
            answer); otherwise use ``normal_template`` (answer only).

    Returns:
        ``(input_message, output_message)``, each a one-element list of
        ``{'role': ..., 'content': ...}`` dicts.
    """
    if include_thought:
        content = thought_template.format(thought=item['thought'], answer=item['answer'])
    else:
        content = normal_template.format(answer=item['answer'])
    input_message = [{'role': 'user', 'content': item['prompt']}]
    output_message = [{'role': 'assistant', 'content': content}]
    return input_message, output_message


def main():
    """Merge the input JSONL files into a chat-format JSONL file.

    Records from ``--input_file_non_cls`` keep their prompt unchanged. Records
    from ``--input_file_cls`` have their prompt rebuilt by
    ``process_data_bug``. Each output line is
    ``{'prompt': [user message], 'response': [assistant message]}``. Before
    writing, the max and mean response length in tokens (as measured with the
    tokenizer's chat template) are printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file_non_cls', type=str, default=None)
    parser.add_argument('--input_file_cls', type=str, default=None)
    parser.add_argument('--output_file', type=str, required=True)
    parser.add_argument('--include_thought', action='store_true')
    # Previously hard-coded; the default preserves the old behaviour.
    parser.add_argument('--tokenizer_path', type=str,
                        default='/data/model_path/models/Qwen3-8B-Base')
    args = parser.parse_args()
    if not (args.input_file_non_cls or args.input_file_cls):
        # With no input at all the length statistics would crash on an empty list.
        parser.error('at least one of --input_file_non_cls / --input_file_cls is required')

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
    non_cls_data = _load_jsonl(args.input_file_non_cls)
    cls_data = process_data_bug(_load_jsonl(args.input_file_cls))

    output_data = []
    response_lengths = []
    for item in tqdm(non_cls_data + cls_data):
        input_message, output_message = _build_messages(item, args.include_thought)
        output_data.append({
            'prompt': input_message,
            'response': output_message,
        })
        # Response length = len(full conversation) - len(prompt with generation
        # prompt appended). This assumes the chat template renders the prompt
        # prefix identically in both calls.
        full_ids = tokenizer.apply_chat_template(input_message + output_message)
        prompt_ids = tokenizer.apply_chat_template(input_message, add_generation_prompt=True)
        response_lengths.append(len(full_ids) - len(prompt_ids))

    if response_lengths:
        print(np.max(response_lengths))
        print(np.mean(response_lengths))
    else:
        # Input files were given but contained no records; np.max([]) would raise.
        print('No records found; skipping length statistics.')
    with jsonlines.open(args.output_file, 'w') as writer:
        writer.write_all(output_data)


if __name__ == '__main__':
    main()
