import argparse
import asyncio
import json
import os
import re
from typing import Optional, Dict, Any, List

from openai import AsyncOpenAI
from tqdm.asyncio import tqdm

# ================= Configuration =================
# System prompt for the conversion model. This MUST be a raw string: in a
# normal string literal "\f" (in \frac) and "\b" (in \boxed) are the form-feed
# and backspace escapes, so the model would receive control characters instead
# of the LaTeX commands it is told to preserve.
SYSTEM_PROMPT = r"""You are a data conversion assistant. Your task is to convert "Action-Centric Structured Reasoning" (PRIME format) into "Standard Natural Language Chain-of-Thought" (ProcessBench/CoT format).

**Goal:**
Rewrite the input solution to remove all structural tags and formatting artifacts, resulting in a smooth, continuous mathematical reasoning narrative.

**Strict Transformation Rules:**
1.  **Remove Structure:**
    -   Delete ALL Action Tags: `[ASSESS]`, `[ADVANCE]`, `[VERIFY]`, `[SIMPLIFY]`, `[SYNTHESIZE]`, `[PIVOT]`, `[OUTPUT]`.
    -   Delete ALL Navigation Tags: `Next action: [...]`.
    -   Delete ALL Bullet points/Prefixes: Remove the `#` symbol at the start of reasoning lines.

2.  **Reconstruct Narrative:**
    -   Merge the fragmented lines (originally starting with `#`) into coherent paragraphs.
    -   Add natural transitional phrases if necessary (e.g., "First,", "Then,", "Next,", "Consequently,", "Finally,") to maintain flow where the Action Tags used to be.
    -   The final output should look like a standard textbook solution or a human-written CoT.

3.  **Preserve Content:**
    -   **DO NOT** change the mathematical logic, numbers, or calculations.
    -   **DO NOT** change the LaTeX formatting (e.g., keep `\frac{...}`, `\boxed{...}`).
    -   Ensure the final answer `\boxed{...}` remains at the end.

4.  **Output Format:**
    -   Output ONLY the rewritten solution text. Do not wrap it in code fences and do not add any commentary before or after it.

**Example Input:**
[ASSESS]
# We need to solve for x.
Next action: [ADVANCE]
[ADVANCE]
# Subtracting 5 from both sides gives x = 2.
Next action: [OUTPUT]
[OUTPUT]
The answer is \boxed{2}.

**Example Output:**
To solve this problem, we first need to isolate x. Subtracting 5 from both sides of the equation gives x = 2. Therefore, the answer is \boxed{2}.
"""

# Dataset loading helper (not a mock: it calls `datasets.load_dataset`).
def load_single_dataset(data_path):
    """Load the ``train`` split of the dataset at ``data_path``.

    Paths ending in ``.json`` or ``.jsonl`` are read with the generic ``json``
    builder. Any other value is passed to ``datasets.load_dataset`` unchanged
    (presumably a Hugging Face Hub dataset name or a dataset directory).
    """
    # Imported here, so `datasets` is only needed when a dataset is actually loaded.
    import datasets

    is_json_file = data_path.endswith(('.jsonl', '.json'))
    if not is_json_file:
        return datasets.load_dataset(data_path, split='train')
    return datasets.load_dataset('json', data_files=data_path, split='train')

# ===========================================

# Matches a Markdown code-fence marker: ``` optionally followed by a language tag.
_CODE_FENCE_RE = re.compile(r"```[\w+-]*")


def _strip_code_fences(text: str) -> str:
    """Remove Markdown code-fence markers (```markdown, ```, ...) and trim whitespace."""
    return _CODE_FENCE_RE.sub("", text).strip()


async def generate_paraphrase(
    client: AsyncOpenAI,
    text: str,
    model_name: str,
    semaphore: asyncio.Semaphore
) -> Optional[str]:
    """Rewrite one structured response into plain chain-of-thought text.

    The API request is made while holding ``semaphore``, so the number of
    concurrent requests stays bounded.

    Args:
        client: Async OpenAI-compatible client.
        text: Original response in the tagged format described by SYSTEM_PROMPT.
        model_name: Name of the chat model to call.
        semaphore: Limiter shared by all concurrent calls.

    Returns:
        The rewritten text with code-fence markers removed, or ``None`` if the
        request failed or the reply was empty after cleaning.
    """
    user_content = f"**Input Trajectory:**\n{text}"

    async with semaphore:
        try:
            response = await client.chat.completions.create(
                model=model_name,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_content},
                ],
                temperature=0.7,
                top_p=0.9,
                max_tokens=4096,
            )
            # `message.content` is typed as optional by the client; treat None as empty.
            content = response.choices[0].message.content or ""
        except Exception as e:
            # Best effort: a failed request drops only this rewrite, not the whole run.
            print(f"Error generating paraphrase: {e}")
            return None

    # The reply may be wrapped in code fences (e.g. ```markdown). Strip them from
    # every reply, not only from replies that still start with a tag.
    cleaned = _strip_code_fences(content)
    # An empty reply is not a usable rewrite.
    return cleaned or None

def _as_list(value: Any) -> List[Any]:
    """Normalize a per-response field to a list: None -> [], list/tuple -> list, scalar -> [scalar]."""
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _fit_length(values: List[Any], length: int, fill: Any) -> List[Any]:
    """Truncate ``values`` or pad it with ``fill`` so that the result has exactly ``length`` items."""
    return values[:length] + [fill] * (length - len(values))


async def process_single_sample(
    client: AsyncOpenAI,
    sample: Dict[str, Any],
    model_name: str,
    semaphore: asyncio.Semaphore
) -> Optional[Dict[str, Any]]:
    """Rewrite every response of one sample and append the rewrites to it.

    The sample is updated in place:

    * ``responses``: original responses followed by the successful rewrites.
    * ``scores``: original scores, then each rewrite inherits its source's score.
    * ``finish_reasons``: original reasons, then ``"rewrite"`` per rewrite.
    * ``is_paraphrased``: set to ``True``.

    The original ``scores``/``finish_reasons`` are first padded (``None`` /
    ``"unknown"``) or truncated to ``len(responses)``. This keeps index ``i`` of
    every list referring to the same response after the rewrites are appended.

    Returns:
        The updated sample, or ``None`` if no rewrite succeeded. The caller
        skips such samples to save space.
    """
    # The raw data is not uniform: a field may be a list, a single value, or missing.
    original_responses = _as_list(sample.get('responses'))
    n_original = len(original_responses)
    original_scores = _fit_length(_as_list(sample.get('scores')), n_original, None)
    original_reasons = _fit_length(_as_list(sample.get('finish_reasons')), n_original, "unknown")

    # Only non-empty string responses can be sent for rewriting.
    # NOTE(review): every such response is rewritten, including ones without
    # "[ADVANCE]"; main() uses that marker only to select samples. Confirm intended.
    rewrite_indices = [
        idx for idx, resp in enumerate(original_responses)
        if isinstance(resp, str) and resp.strip()
    ]

    # generate_paraphrase catches its own errors and returns None, so a plain
    # gather is not aborted by a single failed request.
    results = await asyncio.gather(*(
        generate_paraphrase(client, original_responses[idx], model_name, semaphore)
        for idx in rewrite_indices
    ))

    new_responses: List[str] = []
    new_scores: List[Any] = []
    for idx, rewritten in zip(rewrite_indices, results):
        if rewritten is None:
            continue
        new_responses.append(rewritten)
        # A rewrite keeps the score of the response it was derived from.
        new_scores.append(original_scores[idx])

    if not new_responses:
        return None

    sample['responses'] = original_responses + new_responses
    sample['scores'] = original_scores + new_scores
    sample['finish_reasons'] = original_reasons + ["rewrite"] * len(new_responses)
    sample['is_paraphrased'] = True

    return sample

async def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path",   type=str, required=True, help="Path to the input dataset")
    parser.add_argument("--save_path",   type=str, required=True, help="Path to save the .jsonl output")
    parser.add_argument("--api_key",     type=str, default="sk-xxxx", help="OpenAI API Key")
    parser.add_argument("--base_url",    type=str, default="https://api.deepseek.com", help="API Base URL")
    parser.add_argument("--model",       type=str, default="deepseek-chat", help="Model name")
    parser.add_argument("--num_samples", type=int, default=2048, help="Number of samples to process")
    parser.add_argument("--concurrency", type=int, default=16, help="Max concurrent API calls (higher since we gather inside)")
    
    args = parser.parse_args()

    client = AsyncOpenAI(api_key=args.api_key, base_url=args.base_url)

    print(f"Loading dataset from {args.data_path}...")
    ds = load_single_dataset(args.data_path)
    print(f"Dataset loaded. Total rows: {len(ds)}")

    # 2. 筛选数据
    # 逻辑修改：只要列表里 至少有一条 包含 [ADVANCE]，我们就保留这个样本进行处理
    target_samples = []
    count = 0
    for item in ds:
        res_list = item.get('responses', [])
        if isinstance(res_list, str): res_list = [res_list]
        
        # 只要有一条包含 ADVANCE 就可以改写
        if any("[ADVANCE]" in r for r in res_list):
            target_samples.append(item)
            count += 1
            if count >= args.num_samples:
                break
    
    print(f"Selected {len(target_samples)} samples containing [ADVANCE].")

    # 3. 并发处理
    semaphore = asyncio.Semaphore(args.concurrency)
    
    tasks = [
        process_single_sample(client, sample, args.model, semaphore) 
        for sample in target_samples
    ]

    results = []
    for f in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Processing Samples"):
        res = await f
        if res:
            results.append(res)

    # 4. 保存结果
    print(f"Saving {len(results)} results to {args.save_path}...")
    with open(args.save_path, 'w', encoding='utf-8') as f:
        for item in results:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

    print("Done!")

if __name__ == "__main__":
    asyncio.run(main())



"""

~/verl_cs/.conda/bin/python ~/verl_cs/scripts/mot1_rewrite.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/validation_0_2048_pairwise.json \
    --save_path ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/validation_0_2048_rewrite.json \
    --api_key sk-3e37831a511c4457a9fa6d4edd4bd7cc \
    --model deepseek-chat \
    --base_url https://api.deepseek.com

~/verl_cs/.conda/bin/python ~/verl_cs/scripts/mot1_rewrite.py \
    --data_path ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/validation_0_2048_pairwise.json \
    --save_path ~/LLaMA-Factory-250514/saves_shuyan/llama3.2-3B/prime-sft/prime-rl-rollouts/validation_0_2048_rewrite1.json \
    --api_key sk-3e37831a511c4457a9fa6d4edd4bd7cc \
    --model deepseek-chat \
    --base_url https://api.deepseek.com

"""