import vllm
import torch
from transformers import AutoTokenizer
import argparse
from typing import List
from vllm.outputs import RequestOutput
from evaluation.datasets_loader import get_dataset_handler
import json
import regex as re
import os
STORAGE_PATH = os.getenv("STORAGE_PATH")

def extract_boxed(text):
    results, i = [], 0
    prefix = r'\boxed{'
    plen = len(prefix)

    while True:
        start = text.find(prefix, i)
        if start == -1:
            break   # no more \boxed{…}

        j = start + plen
        depth = 1
        while j < len(text) and depth:
            if text[j] == '{':
                depth += 1
            elif text[j] == '}':
                depth -= 1
            j += 1

        results.append(text[start + plen : j - 1])
        i = j

    return results

def get_response_mask(response_ids, eos_token_id, dtype):
    batch_size, seq_len = response_ids.shape
    mask = torch.ones((batch_size, seq_len), dtype=dtype)
    for i in range(batch_size):
        for j in range(seq_len):
            if response_ids[i][j] == eos_token_id:
                mask[i][j:] = 0
                break
    return mask

def main(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    model = vllm.LLM(
        model=args.model,
        tokenizer=args.model,
        # gpu_memory_utilization=0.8,
        seed=int(args.suffix),
    )
    dataset_handler = get_dataset_handler("math")
    questions, answers = dataset_handler.load_data()
    question = questions[0]
    answer = answers[0]
    chat = [
        {
            "role": "system",
            "content": (
                "You are an expert competition-math problem setter.\n"
                "FIRST, in your private scratch-pad, think step-by-step to design a brand-new, non-trivial problem. "
                "The problem could come from any field of mathematics, including but not limited to algebra, geometry, number theory, combinatorics, prealgebra, probability, statistics, and calculus. "
                "Aim for a difficulty such that fewer than 30 % of advanced high-school students could solve it. "
                "Avoid re-using textbook clichés or famous contest problems.\n"
                "THEN, without revealing any of your private thoughts, output **exactly** the following two blocks:\n\n"
                "<question>\n"
                "{The full problem statement on one or more lines}\n"
                "</question>\n\n"
                r"\boxed{final_answer}"
                "\n\n"
                "Do NOT output anything else—no explanations, no extra markup."
            )
        },
        {
            "role": "user",
            "content": (
                "Generate one new, challenging reasoning question now. "
                "Remember to format the output exactly as instructed."
            )
        }
    ]

    if tokenizer.chat_template:
        prompt = tokenizer.apply_chat_template(
            chat, 
            tokenize=False,
            add_generation_prompt=True, 
            add_special_tokens=True
        )
    else:
        prompt = "system: " + chat[0]["content"] + '\n' + "user: " + chat[1]["content"]
    sample_params = vllm.SamplingParams(
        max_tokens=4096,
        temperature=1.0,
        top_p=0.95,
        n=1,
        stop_token_ids=[tokenizer.eos_token_id],
    )

    completions: List[RequestOutput] = model.generate([prompt]*args.num_samples, sampling_params=sample_params)
    
    # 统计计数器
    total_generated = len(completions)
    success_count = 0  # 成功提取 question 和 answer
    no_question_count = 0  # 无法提取 <question> 标签
    no_answer_count = 0  # 无法提取 \boxed{} 答案
    no_both_count = 0  # 两者都无法提取
    exception_count = 0  # 异常失败
    
    print(f"\n{'='*60}")
    print(f"[GPU {args.suffix}] 问题生成统计")
    print(f"{'='*60}")
    print(f"[GPU {args.suffix}] 总共生成样本数: {total_generated}")
    
    results=[]
    for completion in completions:
        response = completion.outputs[0].text
        try:
            questions = re.findall(r"<question>(.*?)</question>", response, re.DOTALL)
            answers = extract_boxed(response)

            if questions and answers:
                question = questions[-1].strip()
                answer = answers[-1].strip()
                results.append({"question": question, "answer": answer, "score": 0})
                success_count += 1
            else:
                results.append({"question": response, "answer": "", "score": -1})
                # 详细统计失败原因
                if not questions and not answers:
                    no_both_count += 1
                elif not questions:
                    no_question_count += 1
                else:
                    no_answer_count += 1
        except Exception as e:
            results.append({"question": response, "answer": "", "score": -1})
            exception_count += 1
    
    # 打印统计结果
    print(f"\n[GPU {args.suffix}] 过滤统计:")
    print(f"  ✓ 成功提取 (score=0): {success_count} ({100*success_count/total_generated:.1f}%)")
    print(f"  ✗ 失败 (score=-1): {total_generated - success_count} ({100*(total_generated - success_count)/total_generated:.1f}%)")
    print(f"    - 无法提取 <question>: {no_question_count}")
    print(f"    - 无法提取 \\boxed{{}}: {no_answer_count}")
    print(f"    - 两者都无法提取: {no_both_count}")
    print(f"    - 异常错误: {exception_count}")
    print(f"{'='*60}\n")
    
    with open(f"{STORAGE_PATH}/generated_question/{args.save_name}_{args.suffix}.json", "w") as f:
        json.dump(results, f, indent=4)
    
    print(f"[GPU {args.suffix}] 结果已保存到: {STORAGE_PATH}/generated_question/{args.save_name}_{args.suffix}.json")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="Qwen/Qwen3-4B")
    parser.add_argument("--num_samples", type=int, default=1250, help="Number of samples to generate")
    parser.add_argument("--suffix", type=str, default="", help="Suffix to add to the output file")
    parser.add_argument("--save_name", type=str, default="", help="")
    args = parser.parse_args()

    main(args) 