import json
import time
import os
import random
import argparse
import openai
import uuid
import math
from datetime import datetime
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed
import threading
from queue import Queue

def read_json(file_path):
    """Load a JSON file whose top-level value is a list of records.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The parsed list. An empty list is returned when the file does not
        exist, is empty, cannot be parsed as JSON, or holds a top-level
        value that is not a list.
    """
    if not os.path.exists(file_path) or os.path.getsize(file_path) == 0:
        return []

    with open(file_path, "r", encoding="utf-8") as f:
        try:
            data = json.load(f)
        except json.JSONDecodeError:
            print(f"警告：无法解析JSON文件 {file_path}")
            return []

    # Callers iterate over the result and concatenate it with other lists, so
    # a valid-but-non-list document (e.g. a single object) must not leak out.
    if not isinstance(data, list):
        print(f"警告：JSON文件 {file_path} 的顶层不是list，已忽略其内容")
        return []
    return data

def write_json(file_path, data_list):
    """Write a list of records to a JSON file, replacing any previous content.

    The data is first serialized to a temporary file next to the target and
    then moved into place with ``os.replace``. If serialization or writing
    fails, the previous content of ``file_path`` is left untouched instead of
    being truncated.

    Args:
        file_path: Destination path of the JSON file.
        data_list: JSON-serializable data (a list of records) to store.

    Raises:
        OSError: If the temporary file cannot be written or moved into place.
        TypeError: If ``data_list`` contains values that are not
            JSON-serializable.
    """
    # The process id keeps concurrent writers from sharing one temp file.
    tmp_path = f"{file_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data_list, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, file_path)
    except BaseException:
        # Do not leave a partial temp file behind; the original error is
        # re-raised unchanged.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def create_single_request_for_diff(config_item, model_name="deepseek-ai/DeepSeek-V3", temperature=0.0, max_tokens=16):
    """Build the chat messages asking the model to compare two code snippets.

    Args:
        config_item: Mapping providing the "prompt", "chosen" and "rejected"
            entries that are embedded into the request text.
        model_name: Accepted for signature compatibility; not used here.
        temperature: Accepted for signature compatibility; not used here.
        max_tokens: Accepted for signature compatibility; not used here.

    Returns:
        A one-element list holding a single "user" message.

    Raises:
        KeyError: If ``config_item`` lacks one of the three required entries.
    """
    task_prompt, chosen_code, rejected_code = (
        config_item[key] for key in ("prompt", "chosen", "rejected")
    )

    # Fixed instruction text followed by the three task-specific sections.
    instruction = (
        "你是代码评审专家。请比较下面两段代码（chosen 和 rejected），判断它们在代码逻辑或功能实现上是否有明显不同。"
        "如果有明显不同，返回 1；如果只是风格、注释等表面差异，返回 0。"
        "只返回 0 或 1，不要输出其他内容。\n"
    )
    sections = [
        f"- prompt: {task_prompt}\n",
        f"- chosen:\n```python\n{chosen_code}\n```\n",
        f"- rejected:\n```python\n{rejected_code}\n```\n",
    ]
    message = {"role": "user", "content": instruction + "".join(sections)}
    return [message]

def process_single_task(task_data):
    """Send one comparison request to the endpoint (process-pool entry point).

    Args:
        task_data: 7-tuple ``(worker_ip, port, task, model_name, temperature,
            max_tokens, worker_id)`` where ``task`` is a dict with the keys
            "idx", "prompt", "chosen" and "rejected".

    Returns:
        On success ``{"success": True, "result": <record>, "task_idx": idx}``;
        on any exception ``{"success": False, "error": <message>,
        "task_idx": idx, "worker_id": worker_id}``.
    """
    host, port_number, item, model, temp, token_limit, slot = task_data

    try:
        endpoint = f"http://{host}:{port_number}/v1"
        llm = openai.Client(base_url=endpoint, api_key="None")

        # Build the chat payload for this task and send it.
        chat_messages = create_single_request_for_diff(item, model, temp, token_limit)
        completion = llm.chat.completions.create(
            model=model,
            messages=chat_messages,
            temperature=temp,
            max_tokens=token_limit,
        )

        # The stripped text of the first choice is stored under both keys.
        answer = completion.choices[0].message.content.strip()

        # Copy the task fields, then attach the model output and worker label.
        record = {field: item[field] for field in ("idx", "prompt", "chosen", "rejected")}
        record.update(model_result=answer, raw_model_output=answer, worker_id=slot)

        return {"success": True, "result": record, "task_idx": item["idx"]}

    except Exception as exc:
        return {"success": False, "error": str(exc), "task_idx": item["idx"], "worker_id": slot}

def process_tasks_for_worker_parallel(worker_ip, port, tasks, model_name, temperature, max_tokens, num_processes=4):
    """Process a list of tasks through a pool of worker processes.

    Args:
        worker_ip: Host of the endpoint that serves the model.
        port: Port of that endpoint.
        tasks: Task dicts; each is handed to ``process_single_task``.
        model_name: Model identifier forwarded with every request.
        temperature: Sampling temperature forwarded with every request.
        max_tokens: Generation limit forwarded with every request.
        num_processes: Size of the process pool.

    Returns:
        Result records of the successful tasks, in completion order. Failed
        tasks are logged and left out.
    """
    if not tasks:
        print(f"Worker {worker_ip}:{port} 没有分配到任务。")
        return []

    tag = f"[{worker_ip}:{port}]"
    total = len(tasks)
    print(f"{tag} 开始处理 {total} 个任务（多进程并行模式，进程数: {num_processes}）...")

    # Best-effort connectivity probe: a failure is logged, never fatal.
    print(f"{tag} 测试连接到 {worker_ip}:{port}...")
    try:
        import requests
        probe = requests.get(f"http://{worker_ip}:{port}/v1/models", timeout=10)
        print(f"{tag} 连接测试成功，状态码: {probe.status_code}")
        if probe.status_code == 200:
            available = probe.json()
            print(f"{tag} 可用模型: {available}")
    except Exception as conn_e:
        print(f"{tag} 连接测试失败: {conn_e}")

    # One argument tuple per task; the trailing element is a round-robin
    # label that ends up in the result record as "worker_id".
    payloads = [
        (worker_ip, port, item, model_name, temperature, max_tokens, position % num_processes)
        for position, item in enumerate(tasks)
    ]

    collected = []
    ok_count = 0
    bad_count = 0

    with ProcessPoolExecutor(max_workers=num_processes) as pool:
        # Submit everything up front and remember which payload each future carries.
        pending = {pool.submit(process_single_task, payload): payload for payload in payloads}

        for finished in as_completed(pending):
            try:
                outcome = finished.result()
                if not outcome["success"]:
                    bad_count += 1
                    print(f"{tag} 任务 {outcome['task_idx']} 处理失败: {outcome['error']}")
                else:
                    collected.append(outcome["result"])
                    ok_count += 1

                # Progress line after every tenth completed task.
                done_so_far = ok_count + bad_count
                if done_so_far % 10 == 0:
                    print(f"{tag} 已处理 {done_so_far}/{total} 个任务，成功: {ok_count}，失败: {bad_count}")

            except Exception as e:
                bad_count += 1
                failed_payload = pending[finished]
                print(f"{tag} 任务 {failed_payload[2]['idx']} 执行异常: {e}")

    print(f"{tag} 完成处理，成功: {ok_count}，失败: {bad_count}，总计: {total}")
    return collected

def process_tasks_for_worker_threading(worker_ip, port, tasks, model_name, temperature, max_tokens, num_threads=8):
    """Process a list of tasks with a pool of worker threads (alternative mode).

    Each worker thread owns one client, pulls tasks from a shared queue until
    it receives a ``None`` sentinel, and pushes one outcome per task onto a
    result queue that the calling thread drains.

    Args:
        worker_ip: Host of the endpoint that serves the model.
        port: Port of that endpoint.
        tasks: Task dicts with the keys "idx", "prompt", "chosen", "rejected".
        model_name: Model identifier forwarded with every request.
        temperature: Sampling temperature forwarded with every request.
        max_tokens: Generation limit forwarded with every request.
        num_threads: Number of worker threads to start.

    Returns:
        Result records of the successful tasks, in completion order. Failed
        tasks are logged and left out.
    """
    # Function-scope import: the file-level import only brings in ``Queue``.
    from queue import Empty

    if not tasks:
        print(f"Worker {worker_ip}:{port} 没有分配到任务。")
        return []

    worker_log_prefix = f"[{worker_ip}:{port}]"
    print(f"{worker_log_prefix} 开始处理 {len(tasks)} 个任务（多线程并行模式，线程数: {num_threads}）...")

    # Best-effort connectivity probe: a failure is logged, never fatal.
    print(f"{worker_log_prefix} 测试连接到 {worker_ip}:{port}...")
    try:
        import requests
        response = requests.get(f"http://{worker_ip}:{port}/v1/models", timeout=10)
        print(f"{worker_log_prefix} 连接测试成功，状态码: {response.status_code}")
        if response.status_code == 200:
            models_data = response.json()
            print(f"{worker_log_prefix} 可用模型: {models_data}")
    except Exception as conn_e:
        print(f"{worker_log_prefix} 连接测试失败: {conn_e}")

    task_queue = Queue()
    result_queue = Queue()

    for task in tasks:
        task_queue.put(task)

    # One sentinel per thread so that every worker eventually stops.
    for _ in range(num_threads):
        task_queue.put(None)

    def worker_thread():
        """Consume tasks until a sentinel arrives; emit one outcome per task."""
        client = openai.Client(base_url=f"http://{worker_ip}:{port}/v1", api_key="None")

        while True:
            task = task_queue.get()
            if task is None:
                break

            try:
                messages = create_single_request_for_diff(task, model_name, temperature, max_tokens)

                response = client.chat.completions.create(
                    model=model_name,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens
                )

                llm_answer = response.choices[0].message.content.strip()

                result = {
                    "idx": task["idx"],
                    "prompt": task["prompt"],
                    "chosen": task["chosen"],
                    "rejected": task["rejected"],
                    "model_result": llm_answer,
                    "raw_model_output": llm_answer
                }

                result_queue.put({"success": True, "result": result, "task_idx": task["idx"]})

            except Exception as e:
                result_queue.put({"success": False, "error": str(e), "task_idx": task["idx"]})

    threads = []
    for _ in range(num_threads):
        thread = threading.Thread(target=worker_thread)
        thread.start()
        threads.append(thread)

    results = []
    processed_count = 0
    failed_count = 0

    while processed_count + failed_count < len(tasks):
        try:
            result = result_queue.get(timeout=1)
        except Empty:
            # No outcome arrived within the poll interval. That is normal
            # while requests are still in flight, so keep waiting as long as
            # at least one worker is running.
            if any(thread.is_alive() for thread in threads):
                continue
            # Every worker has exited, so anything they produced is already
            # queued; a non-blocking read tells whether something is left.
            try:
                result = result_queue.get_nowait()
            except Empty:
                missing = len(tasks) - processed_count - failed_count
                print(f"{worker_log_prefix} 收集结果时出错: 工作线程已全部退出，仍有 {missing} 个任务没有结果")
                break

        if result["success"]:
            results.append(result["result"])
            processed_count += 1
        else:
            failed_count += 1
            print(f"{worker_log_prefix} 任务 {result['task_idx']} 处理失败: {result['error']}")

        # Progress line after every tenth completed task.
        if (processed_count + failed_count) % 10 == 0:
            print(f"{worker_log_prefix} 已处理 {processed_count + failed_count}/{len(tasks)} 个任务，成功: {processed_count}，失败: {failed_count}")

    for thread in threads:
        thread.join()

    print(f"{worker_log_prefix} 完成处理，成功: {processed_count}，失败: {failed_count}，总计: {len(tasks)}")
    return results

def main(worker_ip, worker_port, json_file_path, result_file, model_name, temperature, max_tokens,
         parallel_mode="process", num_workers=4):
    """Run the diff-scoring job against one endpoint, resuming from earlier results.

    Reads the input list, skips entries whose index already appears in the
    result file, processes the remaining ones in the selected parallel mode,
    and rewrites the result file with the old and new records combined.

    Args:
        worker_ip: Host of the endpoint that serves the model.
        worker_port: Port of that endpoint.
        json_file_path: Input JSON file; a list of objects that each carry
            "prompt", "chosen" and "rejected".
        result_file: JSON file that holds the accumulated results.
        model_name: Model identifier forwarded with every request.
        temperature: Sampling temperature forwarded with every request.
        max_tokens: Generation limit forwarded with every request.
        parallel_mode: "process" for the process pool, "thread" for threads.
        num_workers: Number of worker processes or threads.

    Returns:
        None. Progress and errors are reported on standard output.
    """
    print("=== 调试信息 ===")
    print(f"Worker IP: {worker_ip}")
    print(f"Worker Port: {worker_port}")
    print(f"输入文件: {json_file_path}")
    print(f"输出文件: {result_file}")
    print(f"模型: {model_name}")
    print(f"温度: {temperature}")
    print(f"最大token: {max_tokens}")
    print(f"并行模式: {parallel_mode}")
    print(f"工作进程/线程数: {num_workers}")
    print("=== 开始处理 ===")

    try:
        raw_input_configs = read_json(json_file_path)
        if not raw_input_configs:
            if os.path.exists(json_file_path):
                print(f"警告: 输入文件 '{json_file_path}' 为空或不包含有效的JSON数据。")
            else:
                print(f"错误: 输入JSON文件未找到 {json_file_path}")
                return
    except Exception as e:
        print(f"错误: 读取或解析输入JSON文件 '{json_file_path}' 时发生意外错误: {e}")
        return

    print(f"从输入JSON文件 '{json_file_path}' 加载的总原始配置数: {len(raw_input_configs)}")

    all_task_items = []
    for i, raw_cfg in enumerate(raw_input_configs):
        # Entries must be JSON objects: a membership test on a string would
        # match substrings, and on numbers or null it would raise TypeError.
        if not isinstance(raw_cfg, dict) or not all(k in raw_cfg for k in ("prompt", "chosen", "rejected")):
            print(f"警告: 第{i+1}行缺少必要字段，已跳过。")
            continue
        # The position in the input list is the stable id used for resuming.
        all_task_items.append({
            "idx": i,
            "prompt": raw_cfg["prompt"],
            "chosen": raw_cfg["chosen"],
            "rejected": raw_cfg["rejected"]
        })

    print(f"转换后有效且唯一的任务数: {len(all_task_items)}")

    # Collect the indices that an earlier run already stored.
    processed_indices = set()
    existing_results = []
    if os.path.exists(result_file):
        print(f"正在从结果文件 {result_file} 读取已处理的任务...")
        existing_results = read_json(result_file)
        for item in existing_results:
            if isinstance(item, dict) and "idx" in item:
                idx_val = item["idx"]
                # Accept both integer indices and purely numeric strings.
                if isinstance(idx_val, int):
                    processed_indices.add(idx_val)
                elif isinstance(idx_val, str) and idx_val.isdigit():
                    processed_indices.add(int(idx_val))
        print(f"已恢复 {len(processed_indices)} 个已处理任务的索引。")

    tasks_to_process = [task for task in all_task_items if task["idx"] not in processed_indices]
    print(f"总共需要处理的新任务数 (过滤后): {len(tasks_to_process)}")

    if not tasks_to_process:
        print("没有需要处理的新任务。")
        return

    # Dispatch to the implementation that matches the requested mode.
    if parallel_mode == "process":
        new_results = process_tasks_for_worker_parallel(
            worker_ip,
            worker_port,
            tasks_to_process,
            model_name,
            temperature,
            max_tokens,
            num_workers
        )
    elif parallel_mode == "thread":
        new_results = process_tasks_for_worker_threading(
            worker_ip,
            worker_port,
            tasks_to_process,
            model_name,
            temperature,
            max_tokens,
            num_workers
        )
    else:
        print(f"不支持的并行模式: {parallel_mode}")
        return

    # Earlier records stay in front; new ones are appended and the file is rewritten.
    all_results = existing_results + new_results
    write_json(result_file, all_results)
    print(f"所有任务已处理完成。结果已写入 {result_file}")
    print(f"本次新增结果: {len(new_results)} 条，总计结果: {len(all_results)} 条")

if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    cli = argparse.ArgumentParser(description='使用单个 Worker 并行处理代码差异评分任务，支持断点续传')

    # Endpoint location.
    cli.add_argument('--worker_ip', type=str, required=True, help='运行服务的 Worker IP 地址')
    cli.add_argument('--worker_port', type=int, required=True, help='运行服务的 Worker 端口')

    # Input and output files.
    cli.add_argument('--json_file_path', type=str, required=True, help='输入JSON文件路径 (文件内容为list，每个元素必须包含 "prompt", "chosen", "rejected" 字段)')
    cli.add_argument('--result_file', type=str, default="generated_answers.json", help='结果输出文件路径 (JSON格式)')

    # Generation settings.
    cli.add_argument('--model_name', type=str, default="deepseek-ai/DeepSeek-V3", help='LLM模型名称')
    cli.add_argument('--temperature', type=float, default=0.0, help='LLM温度参数')
    cli.add_argument('--max_tokens', type=int, default=16, help='LLM最大生成token数')

    # Parallelism settings.
    cli.add_argument('--parallel_mode', type=str, default="process", choices=["process", "thread"],
                     help='并行模式: process(多进程) 或 thread(多线程)')
    cli.add_argument('--num_workers', type=int, default=4, help='并行工作进程/线程数量')

    options = cli.parse_args()

    # Every option name equals a keyword parameter of main(), so the parsed
    # namespace can be forwarded as keyword arguments.
    main(**vars(options))

# 使用示例:
# python3 critical_diff_score.py \
# --worker_ip 127.0.0.1 \
# --worker_port 30000 \
# --json_file_path ./test_case.json \
# --result_file ./answer_gen_result.json \
# --model_name "deepseek-ai/DeepSeek-V3" \
# --temperature 0.0 \
# --max_tokens 16 \
# --parallel_mode process \
# --num_workers 8



