# -*- coding: utf-8 -*-
import vllm
import torch
from transformers import AutoTokenizer
import argparse
from typing import List
from vllm.outputs import RequestOutput
from evaluation.datasets_loader import get_dataset_handler
import json
import regex as re
import os
STORAGE_PATH = os.getenv("STORAGE_PATH")
import random
from mathruler.grader import extract_boxed_content, grade_answer

# System-prompt template for the question generator.
#
# It is a raw string, so the backslash in ``\boxed`` is kept literally.
# The template is rendered with ``str.format`` and therefore has four fields:
#   {subject}, {example_problem}, {reference_solution} - filled from the input record;
#   {final_answer} - part of the literal ``\boxed{final_answer}`` output example.
#     ``str.format`` treats it as a field as well, so build_one_prompt passes
#     final_answer="{final_answer}" to make it render unchanged.
my_prompt_no_level = r'''
You are an expert curriculum designer and problem setter for an advanced AI agent. 
You are provided with some specific curriculum specifications including a specific subject, an example problem and the reference solution of the problem.
You are encouraged to analyze the problem, brainstorm and propose a brand-new, multi-step reasoning problem which meets the following requirements:
1. The new problem must require the knowledge of the provided subject.
2. The difficulty of the new problem is comparable to the difficulty of the given Example Problem.
3. Avoid re-using textbook clichés or famous contest problems.
4. The new problem MUST NOT be semantically similar to the provided Example Problem.

FIRST, you must complete the following steps:
1. Analyze the specific subject, the example problem and the reference solution
2. Construct a unique, multi-step problem that meets the above requirements.
3. ***CRITICAL VALIDATION: Self-solve the complete problem step-by-step to ensure it is logically consistent, non-ambiguous, and yields a single, verifiable solution.***

FINALLY, output the problem statement and the verified final answer in the following format:

<think>
[Your complete reasoning process including the analysis, problem construction, and step-by-step self-solution validation as described in the three steps above.]
</think>

<question>
[The complete problem statement on one or more lines]
</question>

\boxed{final_answer}

[SPECIFICATIONS]
1. Subject: {subject}
2. Example Problem: {example_problem} 
3. Reference Solution: {reference_solution}

'''



def build_one_prompt(tokenizer, line):
    """Render the full generation prompt for one input record.

    Args:
        tokenizer: Object exposing ``chat_template`` and ``apply_chat_template``.
        line: JSON text of an object with the keys ``subject``, ``problem``
            and ``gpt_response``.

    Returns:
        The prompt string: the tokenizer's chat template applied to a
        system + user conversation, or a plain ``system: ... / user: ...``
        concatenation when the tokenizer has no chat template.

    Raises:
        json.JSONDecodeError: If ``line`` is not valid JSON.
        KeyError: If one of the required keys is missing.
    """
    record = json.loads(line)

    # ``final_answer`` is substituted with itself so that the literal
    # ``\boxed{final_answer}`` example in the template is left untouched.
    system_text = my_prompt_no_level.format(
        subject=record['subject'],
        example_problem=record['problem'],
        reference_solution=record['gpt_response'],
        final_answer="{final_answer}",
    )
    user_text = (
        "Generate one new, challenging reasoning question now. "
        "Remember to format the output exactly as instructed."
    )

    # No chat template available: fall back to a simple role-prefixed layout.
    if not tokenizer.chat_template:
        return "system: " + system_text + "\n" + "user: " + user_text

    messages = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        add_special_tokens=True,
    )



def get_response_mask(response_ids, eos_token_id, dtype):
    """Build a mask that is 1 before the first EOS token of each row and 0 after.

    The position of the first EOS token itself is masked out (set to 0), as
    is everything that follows it. Rows without an EOS token stay all ones.

    Args:
        response_ids: 2-D tensor of token ids, unpacked as
            ``(batch_size, seq_len)``.
        eos_token_id: Token id that terminates a response. ``None`` means
            "no EOS token" and yields an all-ones mask.
        dtype: dtype of the returned mask.

    Returns:
        A ``(batch_size, seq_len)`` tensor of ``dtype``. It is created by
        ``torch.ones`` without a device argument, i.e. on the default device,
        regardless of where ``response_ids`` lives.

    Raises:
        ValueError: If ``response_ids`` is not 2-D (shape unpacking fails).
    """
    batch_size, seq_len = response_ids.shape
    mask = torch.ones((batch_size, seq_len), dtype=dtype)
    if eos_token_id is None:
        return mask

    # A running count of EOS hits along the sequence axis becomes positive at
    # the first EOS and stays positive afterwards - exactly the span to zero.
    # This replaces a per-element Python loop with a few tensor ops.
    is_eos = torch.as_tensor(response_ids) == eos_token_id
    at_or_after_eos = is_eos.cumsum(dim=1) > 0
    mask[at_or_after_eos.to(mask.device)] = 0
    return mask

def _parse_generation(response):
    """Convert one raw completion into a ``{"question", "answer", "score"}`` record.

    A completion that contains a ``<question>...</question>`` block and a
    boxed answer yields ``score == 0``; anything else (including an error
    while parsing) yields the raw response as ``question``, an empty
    ``answer`` and ``score == -1``.
    """
    failed = {"question": response, "answer": "", "score": -1}
    try:
        questions = re.findall(r"<question>(.*?)</question>", response, re.DOTALL)
        boxed = extract_boxed_content(response)
        if not (questions and boxed):
            return failed
        # extract_boxed_content may hand back either a single string or a
        # sequence of strings depending on the mathruler release. Indexing a
        # string with [-1] would keep only its last character, so a string is
        # used as-is and only a sequence is reduced to its last element.
        # NOTE(review): some mathruler versions appear to return the literal
        # string "None" when no \boxed{} is present - confirm and filter here
        # if that applies to the pinned version.
        answer = boxed if isinstance(boxed, str) else boxed[-1]
        return {
            "question": questions[-1].strip(),
            "answer": answer.strip(),
            "score": 0,
        }
    except Exception:
        # Best effort: a malformed completion must not abort the whole shard.
        # Unlike a bare ``except``, this lets KeyboardInterrupt/SystemExit through.
        return failed


def main(args):
    """Generate new questions for one shard of the input file and save them as JSON.

    ``args.suffix`` doubles as the shard (GPU) index, the vLLM seed and the
    suffix of the output file name. The input file is read as one JSON record
    per line, split into 8 contiguous shards, and every record of this
    process's shard is prompted 8 times. Results are written to
    ``{STORAGE_PATH}/generated_question/{save_name}_{suffix}.json``.

    Args:
        args: Parsed CLI namespace providing ``model``, ``suffix``,
            ``save_name`` and ``input_path``.

    Raises:
        ValueError: If ``args.suffix`` is not an integer string.
        RuntimeError: If the ``STORAGE_PATH`` environment variable is unset.
    """
    # Validate the cheap things first so a misconfiguration fails before the
    # model is loaded and before any generation time is spent.
    gpu_id = int(args.suffix)
    if STORAGE_PATH is None:
        raise RuntimeError("STORAGE_PATH environment variable is not set")
    output_path = f"{STORAGE_PATH}/generated_question/{args.save_name}_{args.suffix}.json"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    with open(args.input_path, 'r', encoding='utf-8') as f:
        # Blank lines (e.g. a trailing empty line) are not JSON records.
        all_lines = [raw for raw in f if raw.strip()]

    num_shards = 8           # the input is always split into 8 contiguous chunks
    samples_per_record = 8   # every record of the shard is prompted this many times

    total = len(all_lines)
    chunk_size = (total + num_shards - 1) // num_shards  # ceil division
    start = gpu_id * chunk_size
    end = min(start + chunk_size, total)

    lines = all_lines[start:end]
    print(f"[GPU {gpu_id}] Loaded {len(lines)} lines from {start} to {end}")

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    model = vllm.LLM(
        model=args.model,
        tokenizer=args.model,
        seed=gpu_id,
    )

    # Repeating the whole shard keeps the original ordering: all records once,
    # then all records again, and so on.
    prompts = [build_one_prompt(tokenizer, line) for line in lines * samples_per_record]

    sample_params = vllm.SamplingParams(
        max_tokens=4096,
        temperature=0.7,
        stop_token_ids=[tokenizer.eos_token_id],
    )

    completions: List[RequestOutput] = model.generate(prompts, sampling_params=sample_params)
    results = [_parse_generation(completion.outputs[0].text) for completion in completions]

    num_parsed = sum(1 for record in results if record["score"] == 0)
    print(f"[GPU {gpu_id}] Parsed {num_parsed} of {len(results)} completions")

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)

if __name__ == "__main__":
    # Command-line entry point: register the options, parse them, run main().
    # --suffix is converted with int() inside main, so it has to be an integer
    # string; --num_samples is accepted but main does not read it.
    parser = argparse.ArgumentParser()
    option_table = (
        ("--model", {"type": str, "default": "Qwen/Qwen3-4B-Base"}),
        ("--num_samples", {"type": int, "default": 1250, "help": "Number of samples to generate"}),
        ("--suffix", {"type": str, "default": "", "help": "Suffix to add to the output file"}),
        ("--save_name", {"type": str, "default": "", "help": ""}),
        ("--input_path", {"type": str, "default": "", "help": ""}),
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    main(args)