import requests
import json
import time
import threading # 恢复 threading
from concurrent.futures import ThreadPoolExecutor # 恢复 ThreadPoolExecutor
import os
# import random # 仍然不需要
import argparse
import openai
# from datetime import datetime # 仍然不需要
import uuid # 用于生成临时文件名
import math # 用于分配任务

# --- 恢复 WorkerIPManager ---
class WorkerIPManager:
    """Keep a cached copy of the worker IP list published by the master.

    The list is fetched once, synchronously, during construction, and then
    refreshed every ``update_interval`` seconds by a daemon thread until
    :meth:`stop_update_thread` is called. Every read and write of the cached
    list goes through ``_ips_lock``; readers receive a copy.

    Args:
        master_url: Base URL of the master; IPs are read from
            ``<master_url>/workers/ips``.
        update_interval: Seconds between background refreshes.
    """

    def __init__(self, master_url, update_interval=300):
        self.master_url = master_url  # base URL used to fetch the worker IP list
        self.update_interval = update_interval
        self.worker_ips = []
        self._ips_lock = threading.Lock()
        self._update_thread = None
        self._stop_event = threading.Event()
        self.update_ips()  # initial, synchronous fetch
        self.start_update_thread()

    def get_worker_ips_from_master(self):
        """Fetch the current worker IP list from the master.

        Returns:
            The ``ips`` list from a ``{"ips": [...]}`` JSON response, or ``[]``
            if the request fails or the payload has any other shape.
        """
        try:
            response = requests.get(f"{self.master_url}/workers/ips", timeout=10)
            response.raise_for_status()
            payload = response.json()
        except requests.exceptions.RequestException as e:
            print(f"无法从 Master 获取 Worker IPs: {str(e)}")
            return []
        except Exception as e:
            print(f"获取 Worker IP 时发生未知错误: {e}")
            return []

        # Validate the payload shape here: a non-object body or "ips": null would
        # otherwise be cached as-is and break get_cached_ips() later.
        ips = payload.get("ips", []) if isinstance(payload, dict) else None
        if not isinstance(ips, list):
            print("Master 返回的 Worker IP 列表格式无效")
            return []
        print(f"从 Master 获取到 Worker IPs: {ips}")
        return ips

    def update_ips(self):
        """Fetch the IP list from the master and replace the cached copy.

        A failed fetch yields an empty list, which replaces the cache and
        triggers a warning.
        """
        new_ips = self.get_worker_ips_from_master()
        # Cache a private list; treat a None/non-sequence result as "no IPs".
        new_ips = list(new_ips) if isinstance(new_ips, (list, tuple)) else []
        with self._ips_lock:
            self.worker_ips = new_ips
            if not self.worker_ips:
                print("警告: 未获取到任何可用的 Worker IP。")

    def _update_loop(self):
        """Refresh the IP list every ``update_interval`` seconds until stopped."""
        # Event.wait returns True as soon as a stop is requested, which ends the
        # loop without one more refresh; it returns False on a plain timeout.
        while not self._stop_event.wait(self.update_interval):
            print("定时更新 Worker IP 列表...")
            self.update_ips()

    def start_update_thread(self):
        """Start the background refresh thread unless one is already running."""
        if self._update_thread is None or not self._update_thread.is_alive():
            self._stop_event.clear()
            self._update_thread = threading.Thread(target=self._update_loop, daemon=True)
            self._update_thread.start()
            print(f"Worker IP 自动更新线程已启动 (间隔: {self.update_interval} 秒)")

    def stop_update_thread(self):
        """Signal the refresh thread to stop and wait for it to exit."""
        self._stop_event.set()
        if self._update_thread:
            self._update_thread.join()
        print("Worker IP 自动更新线程已停止")

    def get_cached_ips(self):
        """Return a copy of the currently cached worker IP list."""
        with self._ips_lock:
            return list(self.worker_ips)

# 移除 call_llm_api 函数 (仍然不需要)

def create_batch_request_item(config, index, model_name="deepseek-ai/DeepSeek-V3", temperature=0.6, max_tokens=32768):
    """Build one Batch API request line asking the model to write a problem.

    Args:
        config: Task config; only ``config["features"]`` is used. It is
            embedded in the prompt as pretty-printed JSON.
        index: Task index, encoded into ``custom_id`` as ``request-<index>``
            so that output lines can be mapped back to their task.
        model_name: Model name placed in the request body.
        temperature: Sampling temperature placed in the request body.
        max_tokens: Completion token limit placed in the request body.

    Returns:
        A dict with ``custom_id``, ``method``, ``url`` and ``body`` keys,
        suitable for one line of a Batch API input JSONL file.

    Raises:
        KeyError: If ``config`` has no ``"features"`` key.
    """
    features_json = json.dumps(config["features"], ensure_ascii=False, indent=4)
    # Turn JSON-escaped "\n" sequences back into real line breaks so that
    # multi-line feature text reads naturally in the prompt. The result is
    # for display only; it is no longer strictly valid JSON.
    features_json = features_json.replace("\\n", "\n")

    prompt = f"""You are a professional competitive programming problem setter.

Please generate a **"very hard"**, Codeforces-style problem based on the following features.  
The problem must be rigorous, original, and meet all the following difficulty standards.

---

**Difficulty Requirements:**

1. **Algorithmic Complexity**: The problem must require a combination of 2–3 advanced algorithmic techniques.

2. **Thinking Depth**: The problem should rely on discovering hidden properties or structural tricks. Brute-force or direct template solutions must not work.

3. **Time Complexity Optimization**: Brute-force solutions (e.g., O(n²) or worse) should time out. Correct solutions should operate within optimized time bounds (e.g., O(n log n), O(n√n), etc.).

4. **Edge Cases and Pitfalls**: The problem must include easy-to-make mistakes or tricky details (e.g., integer overflow, off-by-one errors, dictionary order traps, etc.).

5. **Implementation Difficulty**: The problem should involve complex state transitions or multi-dimensional data structures.

---

**Scenario Requirements:**

To fit the CodeForces-style, the problem proposed should be embedded in a self-consistent scenario.  
The story should naturally motivate the algorithmic constraints and connect logically to the selected features.

---

**Output Format:**

Return your result strictly in the following JSON format:
{{
    "selected_features": [ ... ],
    "question": (codeforces-style problem statement)
}}

Provided Features:
{features_json}
"""
    messages = [{"role": "user", "content": prompt}]

    return {
        "custom_id": f"request-{index}",
        "method": "POST",
        "url": "/v1/chat/completions",  # endpoint the batch runner calls for each line
        "body": {
            "model": model_name,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
    }


def read_jsonl(file_path):
    """Read a JSONL file and return its decoded records in file order.

    Blank lines are ignored. Lines that are not valid JSON (e.g. a record
    truncated by an interrupted write) are reported and skipped. A missing or
    empty file yields ``[]``.
    """
    results = []
    if not os.path.exists(file_path) or os.path.getsize(file_path) == 0:
        return results
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # Blank lines are legal separators, not corrupt records.
                continue
            try:
                results.append(json.loads(stripped))
            except json.JSONDecodeError:
                print(f"警告：跳过无效的JSONL行 in {file_path}")
    return results

# --- 修改 append_to_jsonl 以使用锁 ---
def append_to_jsonl(file_path, data, lock):
    """Append ``data`` as a single JSON line to ``file_path``.

    The whole open-and-write runs while holding ``lock``, so concurrent
    callers sharing the same lock never interleave partial lines. Non-ASCII
    text is written as-is (UTF-8) rather than as ``\\u`` escapes.
    """
    with lock, open(file_path, "a", encoding="utf-8") as sink:
        record = json.dumps(data, ensure_ascii=False)
        sink.write(f"{record}\n")

# --- 修改 process_batch_result_line 以使用锁 ---
def _extract_message_content(response_body, index, custom_id):
    """Return the first choice's ``message.content`` string; raise ValueError if it is absent or malformed."""
    if not isinstance(response_body, dict):
        raise ValueError(f"Response body for task {index} (ID: {custom_id}) is not a JSON object (got: {type(response_body)}).")

    choices_data = response_body.get("choices")
    # Accept both the standard list form and a single choice object.
    if isinstance(choices_data, list):
        if not choices_data:
            raise ValueError(f"Response body for task {index} (ID: {custom_id}) has empty 'choices' list.")
        choice = choices_data[0]
    elif isinstance(choices_data, dict):
        choice = choices_data
    else:
        raise ValueError(f"Response body for task {index} (ID: {custom_id}) has missing or invalid 'choices' field (expected non-empty list or dict, got: {type(choices_data)}).")

    if not isinstance(choice, dict):
        raise ValueError(f"Could not determine a valid 'choice' object for task {index} (ID: {custom_id}).")

    message = choice.get("message")
    if not message or not isinstance(message, dict):
        raise ValueError(f"Choice object for task {index} (ID: {custom_id}) has no 'message' field.")

    content = message.get("content")
    if not content or not isinstance(content, str):
        raise ValueError(f"Message content is missing, empty or not a string for task {index} (ID: {custom_id}).")
    return content


def _extract_json_object_text(llm_content, index, custom_id):
    """Strip Markdown code fences and return the outermost ``{...}`` span; raise ValueError if there is none."""
    cleaned = llm_content.strip()
    if cleaned.startswith("```json") and cleaned.endswith("```"):
        cleaned = cleaned[7:-3].strip()
    elif cleaned.startswith("```") and cleaned.endswith("```"):
        cleaned = cleaned[3:-3].strip()

    json_start = cleaned.find("{")
    json_end = cleaned.rfind("}")
    if json_start == -1 or json_end <= json_start:
        raise ValueError(f"Could not find valid JSON structure in cleaned content for task {index} (ID: {custom_id}).")
    return cleaned[json_start:json_end + 1]


def process_batch_result_line(batch_line_json, result_file, output_lock):
    """Validate one Batch API output line and persist it only on full success.

    The line is appended to ``result_file`` as ``{"idx": <index>, **parsed}``
    only when every check passes:
      * ``custom_id`` has the form ``request-<int>``;
      * there is no Batch-level ``error``;
      * the response has status 200 and a non-empty body;
      * the first choice's message content contains a JSON object.

    Any failure is logged and the line is skipped. The task therefore stays
    absent from ``result_file`` and is retried on the next run. Malformed
    input never raises.

    Args:
        batch_line_json: One decoded line of a Batch API output file.
        result_file: JSONL path that successful results are appended to.
        output_lock: Lock serialising writes to ``result_file`` across threads.
    """
    if not isinstance(batch_line_json, dict):
        print(f"警告: Batch 结果行不是 JSON 对象，跳过: {str(batch_line_json)[:200]}")
        return

    custom_id = batch_line_json.get("custom_id")
    response_data = batch_line_json.get("response")
    error_data = batch_line_json.get("error")

    # custom_id is "request-<idx>" (see create_batch_request_item). A missing or
    # non-string custom_id fails on .split with AttributeError, so catch that too.
    try:
        index = int(custom_id.split("-", 1)[1])
    except (AttributeError, IndexError, ValueError, TypeError):
        print(f"警告: 无法从 custom_id '{custom_id}' 提取索引。跳过此结果行: {batch_line_json}")
        return

    # Error reported by the Batch API itself for this request.
    if error_data:
        print(f"Worker Batch (ID: {custom_id}): 任务 {index} 失败 (Batch API 错误): {error_data}。跳过此任务，将在下次运行时重试。")
        return

    if not isinstance(response_data, dict) or not response_data:
        print(f"警告: Worker Batch (ID: {custom_id}): 任务 {index} 的结果行缺少 'response' 字段。跳过此任务，将在下次运行时重试。")
        return

    status_code = response_data.get("status_code")
    response_body = response_data.get("body")
    if status_code != 200 or not response_body:
        error_message = f"Batch API Response Error (Status: {status_code}, Body: {response_body})"
        print(f"Worker Batch (ID: {custom_id}): 任务 {index} 失败: {error_message}。跳过此任务，将在下次运行时重试。")
        return

    llm_content = None
    try:
        llm_content = _extract_message_content(response_body, index, custom_id)
        parsed_llm_output = json.loads(_extract_json_object_text(llm_content, index, custom_id))
        if not isinstance(parsed_llm_output, dict):
            raise ValueError(f"LLM output for task {index} (ID: {custom_id}) is not a JSON object.")
    except (json.JSONDecodeError, ValueError, TypeError, IndexError, AttributeError) as e:
        error_msg = f"Failed to process/parse LLM output for task {index} (ID: {custom_id}): {str(e)}"
        print(f"Worker Batch 错误: {error_msg}。跳过此任务，将在下次运行时重试。")
        if llm_content is not None:
            print(f"原始 LLM 输出 (部分): {str(llm_content)[:200]}...")
        else:
            print(f"原始 Response Body (部分): {str(response_body)[:200]}...")
        return

    # Keep "idx" as the first key, but never let an "idx" emitted by the model
    # overwrite the real task index (resume logic keys on this field).
    final_result = {"idx": index, **parsed_llm_output}
    final_result["idx"] = index
    append_to_jsonl(result_file, final_result, output_lock)

# --- 修改：处理分配给单个 Worker 的任务的函数，增加 batch_size 控制 ---
# Batch job statuses after which polling stops.
_BATCH_TERMINAL_STATUSES = ("completed", "failed", "cancelled", "expired")


def _wait_for_batch(client, batch_response, log_prefix, poll_interval=20):
    """Poll a batch job until it reaches a terminal status and return the last state seen.

    A job that can no longer be found is marked ``failed``. Other API errors
    are logged, and polling resumes after an extra ``poll_interval`` wait.
    """
    while batch_response.status not in _BATCH_TERMINAL_STATUSES:
        time.sleep(poll_interval)
        try:
            batch_response = client.batches.retrieve(batch_response.id)
        except openai.NotFoundError:
            print(f"{log_prefix}: 轮询时 Batch Job (ID: {batch_response.id}) 未找到。将视为失败。")
            batch_response.status = "failed"  # synthesize a terminal state
            break
        except openai.APIError as poll_error:
            print(f"{log_prefix}: 轮询 Batch Job (ID: {batch_response.id}) 时出错: {poll_error}。将重试...")
            time.sleep(poll_interval)
    return batch_response


def _consume_output_file(client, result_file_id, result_file, output_lock, log_prefix):
    """Download a batch output file, pass each JSON line to process_batch_result_line, and return how many lines were handed over."""
    print(f"{log_prefix}: 正在下载并处理结果文件: {result_file_id}")
    result_content = client.files.content(result_file_id).read().decode("utf-8")

    handled = 0
    for line in result_content.strip().split("\n"):
        if not line.strip():
            continue
        try:
            batch_line_json = json.loads(line)
        except json.JSONDecodeError:
            print(f"{log_prefix}: 警告: 跳过无效的 Batch 结果 JSONL 行: {line[:100]}...")
            continue
        # Parses the line and appends it to result_file only if it fully succeeds.
        process_batch_result_line(batch_line_json, result_file, output_lock)
        handled += 1
    print(f"{log_prefix}: 处理了 {handled} 行结果。")
    return handled


def _cleanup_batch_files(client, log_prefix, batch_input_file, result_file_id, error_file_id, batch_input_file_path):
    """Best-effort deletion of the batch's remote files and the local input file; failures are only logged."""
    remote_files = (
        ("输入文件", batch_input_file.id if batch_input_file else None),
        ("结果文件", result_file_id),
        ("错误文件", error_file_id),
    )
    for label, file_id in remote_files:
        if not file_id:
            continue
        try:
            client.files.delete(file_id)
        except Exception as e:
            print(f"{log_prefix}: 删除{label} {file_id} 时出错: {e}")

    if batch_input_file_path and os.path.exists(batch_input_file_path):
        try:
            os.remove(batch_input_file_path)
        except OSError as e:
            print(f"{log_prefix}: 删除本地临时文件 {batch_input_file_path} 时出错: {e}")


def process_tasks_for_worker(worker_ip, port, tasks, result_file, output_lock, batch_size):
    """Run ``tasks`` on one worker through its OpenAI-compatible Batch API, ``batch_size`` at a time.

    Each chunk of tasks goes through these steps:
      1. Write the requests to a local JSONL file and upload it.
      2. Create a batch job.
      3. Poll until the job reaches a terminal status.
      4. Hand every output line to ``process_batch_result_line``, which
         appends successful results to ``result_file``.
      5. Delete the remote and local temporary files.

    Any failure is logged and the chunk is skipped. Skipped tasks are never
    written to ``result_file``, so a later run retries them. Errors inside a
    chunk do not stop the remaining chunks.

    Args:
        worker_ip: Host of the worker's OpenAI-compatible server.
        port: Port of that server.
        tasks: Task configs; each is expected to have ``idx`` and ``features``.
        result_file: JSONL output path shared by all workers.
        output_lock: Lock serialising writes to ``result_file``.
        batch_size: Maximum number of requests per batch job (positive).
    """
    worker_id_str = f"Worker-{worker_ip}:{port}"
    total_tasks_for_worker = len(tasks)
    print(f"[{worker_id_str}] 开始处理 {total_tasks_for_worker} 个任务，每批最多 {batch_size} 个...")

    client = openai.Client(base_url=f"http://{worker_ip}:{port}/v1", api_key="None")
    processed_count = 0
    num_batches = math.ceil(total_tasks_for_worker / batch_size)

    for batch_index, start in enumerate(range(0, total_tasks_for_worker, batch_size), start=1):
        batch_tasks = tasks[start:start + batch_size]
        batch_info_str = f"批次 {batch_index}/{num_batches} ({len(batch_tasks)} 个任务)"
        log_prefix = f"[{worker_id_str}] {batch_info_str}"
        # .get so that logging the indices can never itself raise KeyError.
        failed_indices_in_batch = [t.get("idx") for t in batch_tasks]
        print(f"[{worker_id_str}] 开始处理 {batch_info_str}...")

        batch_input_file = None
        batch_response = None
        result_file_id = None
        error_file_id = None
        batch_input_file_path = None

        try:
            # 1. Write the requests to a local JSONL file and upload it.
            batch_requests = [create_batch_request_item(config, config["idx"]) for config in batch_tasks]
            batch_input_file_path = f"batch_input_{worker_ip}_batch{batch_index}_{uuid.uuid4()}.jsonl"
            with open(batch_input_file_path, "w", encoding="utf-8") as f:
                f.writelines(json.dumps(req) + "\n" for req in batch_requests)
            with open(batch_input_file_path, "rb") as f:
                batch_input_file = client.files.create(file=f, purpose="batch")

            # 2. Create the batch job. Metadata values are sent as strings,
            #    which is the type the OpenAI Batch API defines for them.
            batch_response = client.batches.create(
                input_file_id=batch_input_file.id,
                endpoint="/v1/chat/completions",
                completion_window="24h",
                metadata={
                    "worker_ip": str(worker_ip),
                    "batch_index": str(batch_index),
                    "num_in_batch": str(len(batch_tasks)),
                },
            )
            print(f"{log_prefix}: Batch Job 已创建，Batch ID: {batch_response.id}, 状态: {batch_response.status}")

            # 3. Block until the job is in a terminal state.
            batch_response = _wait_for_batch(client, batch_response, log_prefix)
            print(f"{log_prefix}: Batch Job 最终状态: {batch_response.status} (ID: {batch_response.id})")

            # 4. Consume the results.
            if batch_response.status == "completed":
                request_counts = getattr(batch_response, "request_counts", None)
                if request_counts:
                    print(f"{log_prefix}: 请求计数: Total={request_counts.total}, Completed={request_counts.completed}, Failed={request_counts.failed}")

                result_file_id = batch_response.output_file_id
                error_file_id = batch_response.error_file_id

                if result_file_id:
                    try:
                        processed_count += _consume_output_file(client, result_file_id, result_file, output_lock, log_prefix)
                    except openai.APIError as download_error:
                        # Nothing from this output file was persisted; the batch is retried next run.
                        print(f"{log_prefix}: 下载或处理结果文件 {result_file_id} 时出错: {download_error}. 此批次任务将在下次运行时重试。")
                else:
                    print(f"{log_prefix}: 警告: Batch Job 已完成但没有输出文件 ID。此批次任务可能未完全成功，未记录的任务将在下次运行时重试。")

                if error_file_id:
                    # The error file is not parsed here, so none of its requests
                    # reach result_file; those tasks are retried on the next run.
                    print(f"{log_prefix}: Batch Job 存在错误文件 (File ID: {error_file_id})。其中失败的请求未写入结果文件，将在下次运行时重试。请检查。")
            else:
                # failed / cancelled / expired: nothing is written, so the whole batch is retried.
                print(f"{log_prefix}: Batch Job 未成功完成 (状态: {batch_response.status})。此批次任务将在下次运行时重试。")
                print(f"{log_prefix}: 涉及的任务索引: {failed_indices_in_batch}")
                top_level_errors = getattr(batch_response, "errors", None)
                if top_level_errors:
                    print(f"{log_prefix}: Batch Job 顶层错误详情: {top_level_errors}")

        except openai.APIError as e:
            # Only APIStatusError subclasses carry status_code; connection and
            # timeout errors do not, so never access it directly.
            status_code = getattr(e, "status_code", None)
            error_type = getattr(e, "type", None)
            print(f"{log_prefix}: 发生 OpenAI API 错误: {e} (Status: {status_code}, Type: {error_type}). 此批次任务将在下次运行时重试。")
            print(f"{log_prefix}: 涉及的任务索引: {failed_indices_in_batch}")
        except Exception as e:
            import traceback
            print(f"{log_prefix}: 处理过程中发生未知错误: {str(e)}. 此批次任务将在下次运行时重试。")
            traceback.print_exc()
            print(f"{log_prefix}: 涉及的任务索引: {failed_indices_in_batch}")
            if batch_response is not None:
                print(f"{log_prefix}: 错误发生时的 Batch Job 状态: {batch_response.status} (ID: {batch_response.id})")
        finally:
            # 5. Clean up remote and local temporary files.
            _cleanup_batch_files(client, log_prefix, batch_input_file, result_file_id, error_file_id, batch_input_file_path)

        print(f"[{worker_id_str}] 完成 {batch_info_str}。")

    print(f"[{worker_id_str}] 处理完成，本次运行处理了 {processed_count} 行结果（包括成功写入的和因错误被跳过的）。未写入的任务将在下次运行时重试。")


# --- 修改 main 函数以接收和传递 batch_size ---
def main(master_url, ip_update_interval, json_file_path, result_file, worker_port, batch_size):
    """Generate problems for every input config not yet present in ``result_file``.

    Steps:
      1. Start a ``WorkerIPManager`` for ``master_url``.
      2. Load the configs from ``json_file_path``.
      3. Skip configs whose ``idx`` already appears in ``result_file``.
      4. Split the remaining tasks into contiguous, near-equal shares, one per
         distinct worker IP.
      5. Run ``process_tasks_for_worker`` for each share in its own thread.

    The IP update thread is stopped on every exit path, including unexpected
    exceptions.

    Args:
        master_url: Master URL that publishes the worker IP list.
        ip_update_interval: Seconds between worker IP refreshes.
        json_file_path: Input JSON file, expected to hold a list of config dicts.
        result_file: JSONL file for results; also read to resume previous runs.
        worker_port: Port of every worker's Batch API server.
        batch_size: Maximum requests per batch job.
    """
    ip_manager = WorkerIPManager(master_url, update_interval=ip_update_interval)
    output_lock = threading.Lock()
    try:
        # Load the input configs.
        try:
            with open(json_file_path, "r", encoding="utf-8") as f:
                all_configs = json.load(f)
        except FileNotFoundError:
            print(f"错误: 输入文件未找到 {json_file_path}")
            return
        except json.JSONDecodeError:
            print(f"错误: 无法解析输入文件 {json_file_path}")
            return
        print(f"输入配置总数: {len(all_configs)}")

        # Resume: collect the integer indices already written by earlier runs.
        processed_indices = set()
        if os.path.exists(result_file):
            processed_indices = {
                item["idx"]
                for item in read_jsonl(result_file)
                if isinstance(item, dict) and isinstance(item.get("idx"), int)
            }
            print(f"从已有结果恢复: {len(processed_indices)} 个已完成结果")

        # Keep only unfinished, usable tasks. A config without an integer "idx"
        # can never be recognised as done on resume, and a missing "idx" or
        # "features" key would make its whole batch fail on every run.
        task_pool = []
        num_unusable = 0
        for config in all_configs:
            if not isinstance(config, dict):
                continue
            idx = config.get("idx")
            if not isinstance(idx, int) or "features" not in config:
                num_unusable += 1
                continue
            if idx not in processed_indices:
                task_pool.append(config)
        if num_unusable:
            print(f"警告: 跳过 {num_unusable} 个缺少整数 'idx' 或 'features' 字段的配置。")
        print(f"待处理任务数量: {len(task_pool)}")

        if not task_pool:
            print("没有需要处理的新任务。")
            return

        # Deduplicate while preserving order: worker_tasks is keyed by IP, so a
        # repeated IP would otherwise overwrite (and drop) another share.
        available_ips = list(dict.fromkeys(ip_manager.get_cached_ips()))
        if not available_ips:
            print("错误: 没有可用的 Worker IP 地址，无法处理任务。请检查 Master 服务或网络连接。")
            return
        print(f"将使用 {len(available_ips)} 个 Worker IP 进行处理: {available_ips}")

        # Contiguous, balanced split: share sizes differ by at most one task.
        num_workers = len(available_ips)
        base_share, remainder = divmod(len(task_pool), num_workers)
        worker_tasks = {}
        start = 0
        for position, worker_ip in enumerate(available_ips):
            share = base_share + (1 if position < remainder else 0)
            if share:
                worker_tasks[worker_ip] = task_pool[start:start + share]
            start += share
        print(f"任务已分配，每个 Worker 大约处理 {math.ceil(len(task_pool) / num_workers)} 个任务。")

        # One thread per worker; each thread submits that worker's batches sequentially.
        with ThreadPoolExecutor(max_workers=len(worker_tasks)) as executor:
            futures = []
            for worker_ip, tasks in worker_tasks.items():
                print(f"提交任务给 Worker: {worker_ip} (共 {len(tasks)} 个任务，分批处理)")
                futures.append(executor.submit(
                    process_tasks_for_worker,
                    worker_ip,
                    worker_port,
                    tasks,
                    result_file,
                    output_lock,
                    batch_size,
                ))

            print("等待所有 Worker 完成所有批次的任务...")
            for future in futures:
                try:
                    future.result()
                except Exception as e:
                    # Safety net for anything the worker function did not handle itself.
                    print(f"一个 Worker 线程执行时遇到顶层错误: {e}")

        print("所有 Worker 的任务已处理完毕。")
    finally:
        ip_manager.stop_update_thread()
        

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='使用多个 Worker 的 Batch API 并行生成竞赛问题')
    parser.add_argument('--master_url', type=str, required=True, help='主服务器URL (用于获取 Worker IP 列表)')
    parser.add_argument('--worker_port', type=int, required=True, help='运行 Batch API 的 Worker 服务端口')
    parser.add_argument('--ip_update_interval', type=int, default=120, help='Worker IP 更新间隔（秒）')
    parser.add_argument('--json_file_path', type=str, required=True, help='输入JSON文件路径')
    parser.add_argument('--result_file', type=str, default="generated_problems.jsonl", help='结果输出文件路径(JSONL格式)')
    parser.add_argument('--batch_size', type=int, default=1000, help='每个 Batch Job 中包含的最大请求数量')

    args = parser.parse_args()

    # Validate numeric options before any network activity.
    # batch_size is used as a slice step and must be positive.
    if args.batch_size <= 0:
        print("错误: --batch_size 必须是正整数。")
        raise SystemExit(1)
    # A non-positive interval makes threading.Event.wait() return immediately,
    # turning the IP refresh thread into a busy loop that hammers the master.
    if args.ip_update_interval <= 0:
        print("错误: --ip_update_interval 必须是正整数。")
        raise SystemExit(1)

    main(
        master_url=args.master_url,
        ip_update_interval=args.ip_update_interval,
        json_file_path=args.json_file_path,
        result_file=args.result_file,
        worker_port=args.worker_port,
        batch_size=args.batch_size,
    )

# 执行示例 (更新):
# python3.12 question_gen.py \
# --master_url http://10.33.0.155:5001 \
# --worker_port 30000 \
# --ip_update_interval 120 \
# --json_file_path /root/epicoder2/data_syn/data/competition_features_parse_80k.json \
# --result_file /root/epicoder2/data_syn/data/competition_features_parse_80k_question.jsonl \
# --batch_size 256