import os
import json
import subprocess
import glob
import math
from pathlib import Path

# ================= Configuration =================

# 1. Paths (change these to your own absolute paths)
# NOTE: PROJECT_ROOT is not referenced by main() below.
PROJECT_ROOT = "/root/autodl-tmp/rStar-rStar-math/train/MCTS"
# INITIAL_MODEL_PATH = "/root/autodl-tmp/rStar-rStar-math/train/model/checkpoint-1960" # initial RM path
# INITIAL_MODEL_PATH = "/root/autodl-tmp/rStar-rStar-math/qwen/Qwen2.5-3B-Instruct" 
# INITIAL_MODEL_PATH = "/root/autodl-tmp/rStar-rStar-math/train/qwen/Qwen2.5-1.5B-Instruct" 
INITIAL_MODEL_PATH = "/root/autodl-tmp/rStar-rStar-math/train/qwen/Qwen2.5-0.5B-Instruct" 
# ALL_TASKS_FILE = "/root/autodl-tmp/rStar-rStar-math/train/MCTS/graph_data_new/gold_paths_verify.json" # the full dataset
ALL_TASKS_FILE = "/root/autodl-tmp/rStar-rStar-math/train/MCTS/gold_paths_all_gold.json"
# ALL_TASKS_FILE = "/root/autodl-tmp/rStar-rStar-math/train/MCTS/graph_data_new/DroidTask/gold_paths_droidtask.json"
OUTPUT_ROOT = "/root/autodl-tmp/rStar-rStar-math/train/MCTS/output/iterative_training_v1"
# Whether to start training from the pretrained model. Kept as a string
# because it is interpolated verbatim into the rollout/train command lines
# as the value of --initial_type.
FROM_PRETRAINED = "False"

# 2. Iteration control
NUM_ITERATIONS = 6          # total number of iterations
TASKS_PER_ITER = 158        # tasks to run per iteration (adjust to your total dataset size)
ACCUMULATE_DATA = True      # True: every training round also includes the data of all previous rounds (to prevent forgetting)

# 3. MCTS hyperparameters (passed through to rollout.py)
MCTS_PARAMS = {
    "max_depth": 60,        # upper bound on search depth (corresponds to your path length)
    "max_iterations": 50,   # [Key] number of searches; more searches mine negatives more thoroughly but make training slower. Suggested: 30-50.
    "exploration_constant": 3.0,
    "top_k": 5              # take the top-k paths for analysis after the search; NOTE: not forwarded by the rollout command built in main()
}

# 4. Training hyperparameters (passed through to train_regression.py)
TRAIN_PARAMS = {
    "num_epochs": 3,
    "lr": "5e-6",
    "batch_size": 4,        # per-device batch size; the A800 has plenty of memory, so this can be raised
    "grad_acc": 4,          # gradient accumulation steps
    "save_limit": 2
}

# ===========================================

def run_cmd(cmd, env=None):
    """Echo a shell command, execute it, and fail loudly on a non-zero exit.

    Args:
        cmd: Command line string; it is executed through the shell.
        env: Optional environment mapping for the child process. ``None``
            lets the child inherit the current process environment.

    Raises:
        RuntimeError: If the command exits with a non-zero return code.
    """
    print(f"\n >>> Running: {cmd}\n")
    # The command goes through the shell. On Windows, path escaping may need
    # care (prefer "/" separators under PowerShell); on Linux/WSL it is
    # usually fine.
    exit_code = subprocess.call(cmd, env=env, shell=True)
    if exit_code == 0:
        return
    raise RuntimeError(f"Command failed with return code {exit_code}")

def _select_task_batch(tasks, iteration, batch_size):
    """Return the round-robin batch of tasks for a 0-based iteration.

    The batch starts at ``(iteration * batch_size) % len(tasks)`` and wraps
    around the end of the list, so every batch holds
    ``min(batch_size, len(tasks))`` tasks instead of being truncated at the
    tail of the pool. ``tasks`` must be non-empty.
    """
    total = len(tasks)
    start = (iteration * batch_size) % total
    count = min(batch_size, total)
    return [tasks[(start + offset) % total] for offset in range(count)]


def _merge_jsonl(sources, destination):
    """Concatenate JSONL files into ``destination``, one record per line.

    Files are streamed line by line, and a missing trailing newline is added
    so the last record of one file is never glued to the first record of the
    next.
    """
    with open(destination, 'w', encoding='utf-8') as outfile:
        for fname in sources:
            with open(fname, 'r', encoding='utf-8') as infile:
                for line in infile:
                    outfile.write(line if line.endswith("\n") else line + "\n")


def main():
    """Run the iterative rollout -> train loop for NUM_ITERATIONS rounds.

    Each iteration:
      A. picks a round-robin batch of tasks and writes it to
         ``iter_<k>/batch_tasks.json``;
      B. runs ``rollout.py`` with the current model to produce
         ``iter_<k>/train_data.jsonl``;
      C. optionally merges the rollout data of all iterations so far;
      D. launches ``train_regression.py`` through ``accelerate launch``;
      E. uses the training output directory as the model for the next round.

    Raises:
        ValueError: If no task in ALL_TASKS_FILE has ``"success" == "Yes"``.
        RuntimeError: If the rollout or training command fails (see run_cmd).
    """
    abs_output_root = os.path.abspath(OUTPUT_ROOT)
    os.makedirs(abs_output_root, exist_ok=True)

    # 0. Build the initial task pool: keep only tasks marked as successful.
    with open(ALL_TASKS_FILE, 'r', encoding='utf-8') as f:
        all_tasks_raw = json.load(f)
    all_tasks = [t for t in all_tasks_raw if t.get("success") == "Yes"]
    print(f"Total valid tasks: {len(all_tasks)}")
    if not all_tasks:
        # Fail early with a clear message instead of a ZeroDivisionError
        # from the batching arithmetic.
        raise ValueError(
            f'No tasks with "success" == "Yes" found in {ALL_TASKS_FILE}'
        )

    current_model = os.path.abspath(INITIAL_MODEL_PATH)
    cumulative_files = []

    # The rollout is restricted to the first GPU; the environment is the same
    # for every iteration, so build it once.
    rollout_env = os.environ.copy()
    rollout_env["CUDA_VISIBLE_DEVICES"] = "0"

    for iteration in range(NUM_ITERATIONS):
        print(f"\n{'='*10} Iteration {iteration+1} / {NUM_ITERATIONS} {'='*10}")
        # Use the absolute root so every path handed to a subprocess is
        # absolute as well.
        iter_dir = os.path.join(abs_output_root, f"iter_{iteration+1}")
        os.makedirs(iter_dir, exist_ok=True)

        # --- A. Task batching (round-robin over the task pool) ---
        current_batch = _select_task_batch(all_tasks, iteration, TASKS_PER_ITER)

        task_file = os.path.join(iter_dir, "batch_tasks.json")
        with open(task_file, 'w', encoding='utf-8') as f:
            json.dump(current_batch, f, indent=2, ensure_ascii=False)

        # --- B. Rollout (MCTS) ---
        rollout_data_path = os.path.join(iter_dir, "train_data.jsonl")

        cmd_rollout = (
            f"python rollout.py "
            f"--model_path \"{current_model}\" "
            f"--task_file \"{task_file}\" "
            f"--output_file \"{rollout_data_path}\" "
            f"--max_iterations {MCTS_PARAMS['max_iterations']} "
            f"--max_depth {MCTS_PARAMS['max_depth']} "
            f"--exploration_constant {MCTS_PARAMS['exploration_constant']} "
            f"--test {True} "
            f"--initial_type {FROM_PRETRAINED} "
            f"--iter {iteration} "
        )

        run_cmd(cmd_rollout, rollout_env)

        # --- C. Prepare the training data ---
        if ACCUMULATE_DATA:
            # Train on the data of every iteration so far: physically merge
            # the per-iteration rollout files into one JSONL file.
            cumulative_files.append(rollout_data_path)
            final_train_path = os.path.join(iter_dir, "merged_train.jsonl")
            _merge_jsonl(cumulative_files, final_train_path)
        else:
            final_train_path = rollout_data_path

        # --- D. Training ---
        # Two processes via `accelerate launch`; the command format targets
        # Linux/WSL and may differ on native Windows.
        model_save_dir = os.path.abspath(os.path.join(iter_dir, "checkpoint"))
        abs_train_path = os.path.abspath(final_train_path)
        cmd_train = (
            f"accelerate launch --num_processes=2 train_regression.py "
            f"--model_name_or_path \"{current_model}\" "
            f"--train_data_path \"{abs_train_path}\" "
            f"--output_dir \"{model_save_dir}\" "
            f"--initial_type {FROM_PRETRAINED} "
            f"--iter {iteration} "
            f"--per_device_train_batch_size {TRAIN_PARAMS['batch_size']} "
            f"--gradient_accumulation_steps {TRAIN_PARAMS['grad_acc']} "
            f"--num_train_epochs {TRAIN_PARAMS['num_epochs']} "
            f"--learning_rate {TRAIN_PARAMS['lr']} "
            f"--evaluation_strategy steps "
            f"--eval_steps 50 "
            f"--save_total_limit {TRAIN_PARAMS['save_limit']} "
            f"--logging_steps 1 "
            f"--bf16 "
            # Flash Attention 2 is requested for the A800.
            f"--attn_impl flash_attention_2 "
            f"--save_strategy epoch "
        )

        # Training inherits the parent environment (no GPU restriction).
        run_cmd(cmd_train)

        # --- E. Update the model path ---
        # Assumes train_regression.py saves the final model directly into
        # output_dir (per the original author's note) — TODO confirm.
        current_model = model_save_dir
        print(f"Iteration {iteration+1} Completed. New model: {current_model}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()