import requests
import json
import time
import threading
from concurrent.futures import ThreadPoolExecutor
import os
import random
import argparse
import openai
from datetime import datetime

# Load the system prompt once at import time; it is sent as the "system"
# message of every request. Decode explicitly as UTF-8 (like the other file
# reads in this script) so the result does not depend on the platform's
# default locale encoding.
with open('cyaron_prompt.txt', 'r', encoding='utf-8') as f:
    system_prompt = f.read()


# Caches the worker IPs reported by the master server and refreshes them periodically
class WorkerIPManager:
    """Keep a cached list of worker IPs, refreshed from the master server.

    Constructing an instance performs one synchronous fetch and then
    re-fetches every ``update_interval`` seconds on a background timer.
    """

    def __init__(self, master_url, update_interval=300):
        """Fetch the worker list once and start the periodic refresh.

        Args:
            master_url: Base URL of the master server; ``/workers/ips`` is
                appended to it for each fetch.
            update_interval: Seconds between two refreshes.
        """
        self.master_url = master_url
        self.update_interval = update_interval
        self.worker_ips = []
        # Handle on the pending refresh timer (None until the first update).
        self._timer = None
        self.update_ips()

    def get_worker_ips(self):
        """Query the master server for the current worker IPs.

        Returns:
            The list stored under the ``'ips'`` key of the JSON response, or
            an empty list when the request fails or the payload does not
            have the expected shape.
        """
        try:
            response = requests.get(f"{self.master_url}/workers/ips", timeout=5)
            response.raise_for_status()
            ips = response.json().get('ips', [])
        except requests.exceptions.RequestException as e:
            print(f"Failed to get worker IPs: {str(e)}")
            return []
        except (ValueError, AttributeError, KeyError):
            # ValueError: body is not valid JSON; AttributeError: the JSON
            # document is not an object and therefore has no ``.get``.
            print("Invalid response format from master server")
            return []

        if not isinstance(ips, list):
            # e.g. {"ips": null} -- callers expect a sequence to choose from.
            print("Invalid response format from master server")
            return []
        return ips

    def update_ips(self):
        """Refresh the cached IP list and schedule the next refresh."""
        self.worker_ips = self.get_worker_ips()
        # Schedule the next update. The timer thread is a daemon so that the
        # self-rescheduling refresh never keeps the interpreter alive once
        # the main thread has finished.
        timer = threading.Timer(self.update_interval, self.update_ips)
        timer.daemon = True
        self._timer = timer
        timer.start()

    def get_cached_ips(self):
        """Return the most recently fetched IP list without contacting the master."""
        return self.worker_ips

def call_llm_api(messages, worker_ip, port=30000, max_retries=3, worker_status=None):
    """Send a chat-completion request to one worker and return the reply text.

    Args:
        messages: Chat messages passed through to the completions endpoint.
        worker_ip: Host of the worker serving the OpenAI-compatible API.
        port: Port of that API on the worker.
        max_retries: Maximum number of attempts before giving up.
        worker_status: Optional mapping keyed by worker IP; when it contains
            ``worker_ip``, that entry's ``"last_success"`` is set to the
            current time after a successful call.

    Returns:
        The content of the first choice's message, or ``None`` when every
        attempt failed.
    """
    client = openai.Client(base_url=f"http://{worker_ip}:{port}/v1", api_key="None")

    for attempt in range(1, max_retries + 1):
        try:
            response = client.chat.completions.create(
                model="deepseek-ai/DeepSeek-V3",
                messages=messages,
                temperature=0.6,
                max_tokens=32768
            )

            # Extract answer content
            answer = response.choices[0].message.content

            # Update last success time if worker_status is provided
            if worker_status and worker_ip in worker_status:
                worker_status[worker_ip]["last_success"] = datetime.now()

            return answer

        except Exception as e:
            if attempt < max_retries:
                print(f"API call failed to {worker_ip}: {str(e)}. Retrying in 1 second... (Attempt {attempt}/{max_retries})")
                time.sleep(1)
            else:
                # Last attempt: no retry follows, so do not claim one and do
                # not waste a second sleeping before reporting the failure.
                print(f"API call failed to {worker_ip}: {str(e)}. (Attempt {attempt}/{max_retries})")

    print(f"Maximum retries reached. Failed to process prompt using worker {worker_ip}.")
    return None

def process_question(sample, index, worker_ips):
    """Generate test code for one sample using a randomly chosen worker.

    Args:
        sample: Mapping with at least the keys ``"question"`` and ``"answer"``.
        index: Identifier of the sample, copied into the result.
        worker_ips: Sequence of worker IPs to choose from.

    Returns:
        A dict with the keys ``"index"``, ``"question"``, ``"answer"`` and
        ``"test_code"`` (the model reply), or ``None`` when no worker is
        available or the API call produced no answer.
    """
    # Without any worker there is nothing to call; report it as an ordinary
    # failure (None) instead of letting random.choice raise IndexError.
    if not worker_ips:
        print(f"No worker IPs available; cannot process config {index}.")
        return None

    # System prompt followed by the question itself as the user message.
    messages = [{"role": "system", "content": system_prompt},
                {"role": "user", "content": sample["question"]}]

    # Call the API to get the result
    print(f"Processing config {index}...")
    selected_ip = random.choice(worker_ips)
    response = call_llm_api(messages, selected_ip)

    # Only build a result when a valid answer was obtained; None tells the
    # caller that this sample still needs to be processed.
    if response is None:
        return None

    return {
        "index": index,
        "question": sample["question"],
        "answer": sample["answer"],
        "test_code": response
    }

def main(master_url, parallel_num, ip_update_interval, json_file_path, result_file):
    """Process every sample in batches, checkpointing results after each batch.

    Samples whose ``"index"`` already appears in ``result_file`` are skipped,
    so an interrupted run can be resumed. Samples that fail stay in the task
    pool and are retried in later batches until they succeed.

    Args:
        master_url: Base URL of the master server that lists worker IPs.
        parallel_num: Batch size and number of worker threads per batch.
        ip_update_interval: Seconds between refreshes of the worker IP list.
        json_file_path: Input JSON file holding the list of samples.
        result_file: JSON file the accumulated results are written to.
    """
    ip_manager = WorkerIPManager(master_url, update_interval=ip_update_interval)

    # Seconds to wait before re-checking the cached worker list when it is empty.
    no_worker_wait = 5

    # Read from the JSON file
    with open(json_file_path, "r", encoding="utf-8") as f:
        samples = json.load(f)
    print(f"Length of samples: {len(samples)}")

    # Pool of finished results
    result_pool = []

    # Indices already present in an existing result file
    processed_indices = set()

    if os.path.exists(result_file):
        with open(result_file, "r", encoding="utf-8") as f:
            result_pool = json.load(f)
            processed_indices = {item["index"] for item in result_pool}
            print(f"从已有结果恢复: {len(result_pool)} 个已完成结果")

    # Keep only the samples that have not been processed yet
    task_pool = [sample for sample in samples if sample["index"] not in processed_indices]
    print(f"待处理任务数量: {len(task_pool)}")

    # Keep going until the task pool is empty
    while task_pool:
        worker_ips = ip_manager.get_cached_ips()
        if not worker_ips:
            # Every task would fail immediately; wait for the background
            # refresh instead of spinning and rewriting the result file.
            print(f"No worker IPs available; waiting {no_worker_wait} seconds before retrying...")
            time.sleep(no_worker_wait)
            continue

        # Take the current batch without removing it from the task pool
        current_batch = task_pool[:parallel_num]

        # Indices of the tasks that succeeded in this batch
        successful_indices = set()

        with ThreadPoolExecutor(max_workers=parallel_num) as executor:
            # Submit the tasks of the current batch
            future_to_config = {
                executor.submit(process_question, sample, sample["index"], worker_ips): sample
                for sample in current_batch
            }

            # Collect the results
            for future in future_to_config:
                sample = future_to_config[future]
                index = sample["index"]
                try:
                    result = future.result()
                    if result is None:  # no valid answer was obtained
                        print(f"任务 {index} 处理失败，将在下一批次重试")
                    else:
                        result_pool.append(result)
                        successful_indices.add(index)
                        print(f"成功处理配置 {index}")
                except Exception as e:
                    print(f"处理配置 {index} 时出错: {str(e)}，将在下一批次重试")

        # After each batch, drop the successful tasks from the task pool
        task_pool = [sample for sample in task_pool if sample["index"] not in successful_indices]

        # Save the full result set. Write to a temporary file first and then
        # swap it in, so an interruption mid-write cannot leave a truncated
        # result file that would break resuming.
        tmp_file = f"{result_file}.tmp"
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(result_pool, f, ensure_ascii=False, indent=2)
        os.replace(tmp_file, result_file)
        print(f"已保存当前结果，共 {len(result_pool)} 条数据，剩余 {len(task_pool)} 个任务待处理")

if __name__ == "__main__":
    # Command-line entry point. Every option has a default, and the option
    # names match main()'s parameter names one-to-one, so the parsed
    # namespace can be forwarded as keyword arguments.
    cli = argparse.ArgumentParser(description='生成竞赛问题')
    cli.add_argument(
        '--master_url', type=str, default="http://10.33.0.5:5000",
        help='主服务器URL',
    )
    cli.add_argument(
        '--parallel_num', type=int, default=128,
        help='并行处理任务数量',
    )
    cli.add_argument(
        '--ip_update_interval', type=int, default=120,
        help='IP更新间隔（秒）',
    )
    cli.add_argument(
        '--json_file_path', type=str, default="./competition_features_parse_20k.json",
        help='输入JSON文件路径',
    )
    cli.add_argument(
        '--result_file', type=str, default="generated_problems.json",
        help='结果输出文件路径',
    )

    options = cli.parse_args()

    main(**vars(options))


# python3 testcase_generation_cyaron.py \
# --master_url http://10.33.0.5:5001 \
# --parallel_num 16 \
# --ip_update_interval 180 \
# --json_file_path ./competitive_feature_merged_80k_generated_10k_qa.json \
# --result_file ./competitive_feature_merged_80k_generated_qa_testcode.json


# python3 testcode_generation.py --input /root/generated_problems.json --output results/V3_problems_testcode.jsonl --prompt openr1_prompt.txt --samples 3 --workers 32 --master_url http://10.33.0.5:5001 --checkpoint_dir ./test_case_checkpoints

