import argparse
import json
import os
import random
import time
import uuid

import sys
from transformers import AutoModelForCausalLM, AutoTokenizer

# Largest native index value; used below as an open-ended slice bound in
# batch_data and as the default for the --end command-line option.
MAX_INT = sys.maxsize
# Presumably a marker for answers that could not be parsed; it is not
# referenced anywhere else in the visible code -- TODO confirm usage.
INVALID_ANS = "[invalid]"

def remove_boxed(raw_response):
    r"""Return the contents of the last ``\boxed{...}`` expression in a response.

    Args:
        raw_response: Text to search.

    Returns:
        The string between ``\boxed{`` and its matching closing brace, or
        ``None`` when no such expression is found (including when the last
        match is an ``\fbox`` or a ``\boxed`` that is not immediately
        followed by ``{``).
    """
    s = last_boxed_only_string(raw_response)
    left = "\\boxed{"
    # Explicit checks instead of assert/bare-except: asserts are stripped
    # under ``python -O``, which would let malformed matches slip through.
    if s is None or not s.startswith(left) or not s.endswith("}"):
        return None
    return s[len(left):-1]

def last_boxed_only_string(string):
    r"""Return the last ``\boxed...`` expression in *string*, braces included.

    The search starts at the final occurrence of ``\boxed``; only when there
    is none does it fall back to the final occurrence of ``\fbox``. From
    there, braces are counted until the depth returns to zero on a ``}``.

    Returns:
        The substring from the command through its balancing ``}``, or
        ``None`` if no command is present or the braces never balance.
    """
    begin = string.rfind("\\boxed")
    if begin < 0:
        begin = string.rfind("\\fbox")
    if begin < 0:
        return None

    depth = 0
    for pos in range(begin, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            # Only a closing brace can complete the expression.
            if depth == 0:
                return string[begin:pos + 1]
    return None

def batch_data(data_list, batch_size=1):
    """Split *data_list* into consecutive batches.

    Every batch except the last holds exactly ``batch_size`` items. The last
    batch takes everything that remains, so it may be larger than
    ``batch_size`` (or smaller when the input is shorter than one batch).
    An empty input yields a single empty batch.

    Args:
        data_list: Sliceable sequence to partition.
        batch_size: Number of items per regular batch.

    Returns:
        A list of slices of *data_list*.
    """
    whole_batches = len(data_list) // batch_size
    # Regular, fixed-size batches: all but the final one.
    chunks = [
        data_list[k * batch_size:(k + 1) * batch_size]
        for k in range(whole_batches - 1)
    ]
    # The final batch absorbs the remainder up to the end of the sequence.
    tail_begin = (whole_batches - 1) * batch_size
    chunks.append(data_list[tail_begin:MAX_INT])
    return chunks

def pred_ans_extract(pred_response):
    """Pull the final-answer portion out of a model response string.

    Steps, in order:
      1. Narrow to the text inside the last ``<answer>`` tag, else the text
         after the last ``</think>``, else the whole response.
      2. If an HTML-comment marker (``<!-- Final Answer -->`` or
         ``<!-- Summarize Answer -->``) is present, take the span it
         delimits. Otherwise use the first matching keyword rule and keep
         what follows its last occurrence, with colons and ``**`` removed.
         With no keyword at all, keep the tail of the text (last ten
         blank-line-separated paragraphs, capped at 300 characters).
      3. Strip the ``<|..._of_solution|>`` / ``<|..._of_thought|>`` tokens.

    Returns:
        The extracted answer text.
    """
    if "<answer>" in pred_response:
        text = pred_response.split('<answer>')[-1].split("</answer>")[0]
    elif '</think>' in pred_response:
        text = pred_response.split('</think>')[-1]
    else:
        text = pred_response

    if '<!-- Final Answer -->' in text:
        text = text.split('<!-- Final Answer -->')[1].strip()
        text = text.split('<!-- Finish Answer -->')[0].strip()
    elif '<!-- Summarize Answer -->' in text:
        text = text.split('<!-- Summarize Answer -->')[-1].strip()
        text = text.split('<!-- Finish Response -->')[0].strip()
    else:
        # (keyword, colon character to drop) -- checked in priority order.
        keyword_rules = (
            ("最终答案", "："),
            ("**答案**", "："),
            ("Final Answer", ":"),
            ("答案：", "："),
            ("**Answer**", ":"),
            ("Answer:", ":"),
            ("总结", "："),
            ("Conclusion", ":"),
        )
        for keyword, colon in keyword_rules:
            if keyword in text:
                text = text.split(keyword)[-1].replace(colon, "").replace("**", "").strip()
                break
        else:
            # No keyword found: fall back to the tail of the response.
            text = "\n\n".join(text.split("\n\n")[-10:])
            text = text[-300:].strip()

    for special_token in (
        "<|end_of_solution|>",
        "<|begin_of_solution|>",
        "<|end_of_thought|>",
        "<|begin_of_thought|>",
    ):
        text = text.replace(special_token, "")
    return text

def load_data(in_file):
    """Load a JSONL file, drop filtered records, and number the survivors.

    A record is skipped when its ``tixing`` field contains ``选择题``
    (multiple-choice) or when its stripped ``split`` field is one of
    ``高中代数``, ``高中几何``, ``高中概率`` or ``大学``. Each kept record
    gets a ``gidx`` key holding its 0-based position among kept records.

    Args:
        in_file: Path to a UTF-8 file with one JSON object per line.

    Returns:
        List of the kept records (dicts), in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``tixing`` or ``split``.
    """
    excluded_splits = ('高中代数', '高中几何', '高中概率', '大学')
    datas = []
    # ``with`` guarantees the handle is closed even if parsing fails.
    with open(in_file, encoding='utf-8') as fh:
        for line in fh:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            js = json.loads(line)
            if '选择题' in js['tixing']:
                continue
            if js['split'].strip() in excluded_splits:
                continue
            js['gidx'] = len(datas)
            datas.append(js)
    return datas


def prompt_format(messages, tokenizer, is_tokenizer=False):
    """Render *messages* through the tokenizer's chat template.

    Args:
        messages: Conversation passed straight to ``apply_chat_template``.
        tokenizer: Object exposing ``apply_chat_template``.
        is_tokenizer: Forwarded as the ``tokenize`` option.

    Returns:
        Whatever ``tokenizer.apply_chat_template`` returns. The call always
        sets ``add_generation_prompt=False`` and ``enable_thinking=False``.
    """
    template_options = {
        "tokenize": is_tokenizer,
        "add_generation_prompt": False,
        "enable_thinking": False,
    }
    return tokenizer.apply_chat_template(messages, **template_options)

def statistic(hendrycks_math_ins, hendrycks_math_ins_tokenizer, k=4000):
    """Group prompts into fixed-width length buckets and print bucket sizes.

    Args:
        hendrycks_math_ins: Iterable of prompts.
        hendrycks_math_ins_tokenizer: Iterable of lengths, paired
            positionally with the prompts.
        k: Bucket width.

    Returns:
        A ``defaultdict(list)`` mapping ``"<start>-<end>"`` labels to the
        prompts whose length falls in ``[start, start + k)``.
    """
    from collections import defaultdict

    grouped = defaultdict(list)
    for size, text in zip(hendrycks_math_ins_tokenizer, hendrycks_math_ins):
        lower = (size // k) * k
        grouped[f"{lower}-{lower + k}"].append(text)

    def _lower_bound(label):
        # Sort numerically by the bucket's starting value.
        return int(label.split('-')[0])

    for label in sorted(grouped, key=_lower_bound):
        print(f"{label} token: {len(grouped[label])} prompt")
    return grouped

def split_text(solution, token_nums, k=3000):
    """Split *solution* on line boundaries into roughly equal pieces.

    The target number of pieces is ``token_nums // k + 1``. Lines are grouped
    into runs of ``total_lines // pieces`` lines each, and the final two runs
    are then joined into one, so the last piece absorbs any remainder.

    Args:
        solution: Text to split.
        token_nums: Token count of *solution*, used only to size the split.
        k: Token budget per piece.

    Returns:
        List of text pieces; ``[solution]`` when the text cannot be divided
        (zero target pieces or fewer lines than pieces).
    """
    n_parts = token_nums // k + 1
    if n_parts == 0:
        return [solution]

    lines = solution.split("\n")
    lines_per_part = len(lines) // n_parts
    if lines_per_part == 0:
        return [solution]

    pieces = [
        "\n".join(lines[lo:lo + lines_per_part])
        for lo in range(0, len(lines), lines_per_part)
    ]
    # Merge the last two pieces into one.
    return pieces[:-2] + ["\n".join(pieces[-2:])]

def synthesis_prompt(model, data_path, start, end, batch_size, tensor_parallel_size, args):
    """Build code-augmentation rewrite prompts from sampled reasoning traces.

    Reads a JSONL corpus, randomly samples up to 60000 records, extracts the
    text between ``<think>`` and ``</think>`` from each record's ``output``,
    splits it with ``split_text`` (3000-token budget per piece) and fills the
    system prompt template with each piece. The resulting prompt records are
    written as JSON lines to ``split_1.json`` .. ``split_3.json`` and the
    sampled source records (with an added ``uid``) to ``text_data.json``,
    all under ``datas/save_synthesis_response_prompt_dir``.

    Args:
        model, data_path, start, end, batch_size, tensor_parallel_size, args:
            Accepted so the command-line entry point can call this function,
            but none of them is read here; the tokenizer name and the corpus
            path are fixed inside the function body.

    Returns:
        None. Results are written to disk and progress is printed.
    """

    system_prompt = """You are a professional assistant with expertise in mathematics and Python programming, and a multiple gold medalist in mathematics olympiads and programming competitions. You have the capability to write code and use a code interpreter for calculation. The code will be executed in a sandbox environment, and the results will enhance the reasoning process.

## Task Objective

Enhance the provided mathematical problem-solving process by replacing complex manual calculations with Python code and their execution results, while preserving the original reasoning logic and structure.

## Instructions

### 1. Identify Computational Steps for Code Replacement

Identify steps that would benefit from code execution, including:
- **Complex symbolic algebra**: polynomial expansion, factorization, solving equations
- **Advanced calculus**: differentiation, integration, evaluating definite integrals
- **Probability and combinatorics**: complex counting, probability distributions
- **Linear algebra**: matrix operations, inversion, eigenvalue decomposition
- **Numerical computations**: approximations, large number calculations, geometric calculations
- **Any error-prone or computationally intensive calculations**

**Important**: Simple calculations (basic arithmetic like 2×3=6) should remain as text. Only use code for complex computations that require at least 5 lines of implementation.

### 2. Code Implementation Requirements

Each code snippet must:
- Be complete and executable, including all necessary imports
- Use appropriate libraries (`sympy`, `numpy`, `scipy`, etc.)
- Include clear variable definitions and comments
- Explicitly use `print()` for all outputs
- Demonstrate the computation process, not just final results
- Contain at least 5 lines of code

### 3. Integration Guidelines

- **Preserve the original reasoning flow**: Keep all logical steps, explanations, and even failed attempts unchanged
- **Seamless integration**: Code and results should naturally fit within the surrounding text
- **Context preservation**: Maintain semantic coherence and logical consistency
- **No unnecessary changes**: Do not modify, delete, or polish unrelated content

### 4. Formatting Requirements

Wrap each code snippet as follows:
<code>
```python
# Code implementation
import ...
# computation steps
print(...)
```
</code>

Follow immediately with execution results:
<interpreter>
execution results here
</interpreter>

## Input Format

<original_thinking_process>
{original_response}
</original_thinking_process>

## Output Format

Provide the enhanced thinking process wrapped in:
<revised_thinking_process>
[Enhanced solution with code snippets and execution results integrated]
</revised_thinking_process>

## Key Principles

1. **Code Necessity**: Use code only for complex calculations that warrant automation
2. **Minimal Changes**: Modify only the computational steps, preserving all other text
3. **Complete Scripts**: Each code block must be self-contained and executable
4. **Educational Value**: Code should demonstrate the computation process clearly
5. **Accuracy**: Execution results must match exactly what the code produces
6. **"No Code Needed" Principle**: For simple calculations (such as `2 * 3 = 6` or `10 / 2 = 5`, etc.), there is no need to generate a code block. Simply retain or write out the result directly. The four basic arithmetic operations—addition, subtraction, multiplication, and division—do not require the use of code for calculation. The number of lines of code must exceed five.
7. **"Code Needed" Principle**: When the math involved is complex, error-prone, high in difficulty, computationally intensive, or inconvenient to perform manually, code should be used for assistance. This includes, but is not limited to: complex symbolic algebra (e.g., polynomial expansion, factorization, solving equations, etc.), complex advanced calculus operations (e.g., differentiation, integration, evaluating definite integrals, etc.), complex intricate probability and combinatorial analysis, complex advanced linear algebra operations (e.g., matrix calculations, inversion, eigenvalue decomposition, etc.), as well as complex numerical approximations, complex large number computations, complex geometric calculations, or other complex logical operations that are not suitable for manual calculation, etc..
9. **Review and Verify**: Carefully examine the revised thinking process to ensure that every code computation meets the required complexity standards and that each code block contains at least five lines.
"""
    print('promt =====', system_prompt)
    model_name = "Qwen/Qwen3-8B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Open-source data downloaded from Hugging Face: a-m-team/AM-DeepSeek-R1-0528-Distilled
    text_path = "a-m-team/AM-DeepSeek-R1-0528-Distilled/math.jsonl"

    # Instructions to skip; always empty here, so no record is filtered out.
    exist_call_data_dict = {}

    all_text_datas = []
    # Read as UTF-8 explicitly, matching the UTF-8 output files written below
    # instead of depending on the platform's default encoding.
    with open(text_path, 'r', encoding='utf-8') as file:
        for idx, line in enumerate(file):
            if idx % 10000 == 0:
                print(f'idx===={idx}')
            item = json.loads(line)
            if item['instruction'] in exist_call_data_dict:
                continue
            all_text_datas.append(item)

    # Cap the sample at the corpus size: random.sample raises ValueError when
    # asked for more items than the population holds.
    sample_size = min(60000, len(all_text_datas))
    random_sample_2w = random.sample(all_text_datas, k=sample_size)

    start_time = time.time()
    # Reasoning text: what lies between the last <think> and the first </think>.
    think_texts = [
        item['output'].split("</think>")[0].split('<think>')[-1]
        for item in random_sample_2w
    ]
    if think_texts:
        token_id_lists = tokenizer(
            think_texts,
            add_special_tokens=False,
            padding=False,
            truncation=False,
        )["input_ids"]
    else:
        # Nothing sampled: skip the tokenizer call entirely.
        token_id_lists = []
    end_time = time.time()
    print(f"enc_tokenizer_results time ==== {end_time - start_time}")

    call_datas = []
    for item, think_process, token_ids in zip(random_sample_2w, think_texts, token_id_lists):
        # One uid ties every prompt piece back to its source record.
        uid = str(uuid.uuid4())
        item['uid'] = uid
        instruction = item['instruction']
        token_nums = len(token_ids)
        all_instructions = [
            system_prompt.format(original_response=text_sp)
            for text_sp in split_text(think_process, token_nums, k=3000)
        ]
        for turn, sys_instr in enumerate(all_instructions):
            call_datas.append({
                'uid': uid,
                'instruction': instruction,
                'question': sys_instr,
                'turn': turn,
                'total_turns': len(all_instructions),
                'total_token_nums': token_nums,
            })

    save_path_dir = 'datas/save_synthesis_response_prompt_dir'
    os.makedirs(save_path_dir, exist_ok=True)

    # Write the prompt records as split_num consecutive shards.
    split_num = 3
    slice_nums = len(call_datas) // split_num + 1
    for i in range(split_num):
        split_call_datas = call_datas[i * slice_nums:(i + 1) * slice_nums]
        save_path = f'{save_path_dir}/split_{i+1}.json'
        with open(save_path, 'w', encoding='utf-8') as f:
            for record in split_call_datas:
                f.write(json.dumps(record, ensure_ascii=False) + '\n')

        print(f'save_path======{save_path}, \n==========infer done==============')

    # Also keep the sampled source records (now carrying their uid).
    save_path_text = f'{save_path_dir}/text_data.json'
    with open(save_path_text, 'w', encoding='utf-8') as f:
        for item in random_sample_2w:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

    print(f'save_path======{save_path_text}, \n==========infer done==============')
def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the process command line.

    Returns:
        argparse.Namespace with ``model``, ``data_file``, ``start``, ``end``,
        ``batch_size``, ``tensor_parallel_size`` and ``max_tokens``.
    """
    parser = argparse.ArgumentParser()
    # A string default keeps args.model a str; argparse does not apply
    # ``type`` to non-string defaults, so a default of 0 would stay an int.
    parser.add_argument("--model", type=str, default="Qwen/Qwen3-8B", help="model path")
    parser.add_argument("--data_file", type=str, default='data/MATH_test.jsonl', help="data path")
    parser.add_argument("--start", type=int, default=0, help="start index")
    parser.add_argument("--end", type=int, default=MAX_INT, help="end index")
    parser.add_argument("--batch_size", type=int, default=50, help="batch size")
    parser.add_argument("--tensor_parallel_size", type=int, default=1, help="tensor parallel size")
    parser.add_argument("--max_tokens", type=int, default=16384, help="max tokens")
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: parse the CLI flags and hand them to the
    # prompt-synthesis routine, one keyword per option.
    args = parse_args()
    synthesis_prompt(
        model=args.model,
        data_path=args.data_file,
        start=args.start,
        end=args.end,
        batch_size=args.batch_size,
        tensor_parallel_size=args.tensor_parallel_size,
        args=args,
    )


'''
Example usage:

python data_synthesis/synthesis_prompt.py --model Qwen/Qwen3-8B --data_file /root/datas/AM-DeepSeek-R1-0528-Distilled/math.jsonl --batch_size 5000 --tensor_parallel_size 1 --max_tokens 16384

'''
