import json
import requests
import logging
import time
import os
import threading
import random
import argparse
import sys
from datetime import datetime
from typing import List, Dict, Any, Tuple, Optional
from collections import deque
from concurrent.futures import ThreadPoolExecutor
import re

# Logging setup: every record goes both to a per-run, timestamped log file in
# the current working directory and to a console StreamHandler.
_log_file_name = f'test_gen_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(_log_file_name),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

class SGLangClient:
    """Minimal client for an SGLang server's OpenAI-compatible chat-completions API."""

    # Seconds allowed to establish the TCP connection. Kept short so a dead
    # server is detected quickly even though the read timeout is long.
    CONNECT_TIMEOUT = 10

    def __init__(self, port: int = 8000, base_url: str = "http://localhost",
                 timeout: float = 1800.0):
        """
        Args:
            port: Server port, appended to ``base_url``.
            base_url: Scheme and host without a trailing slash,
                e.g. ``"http://10.0.0.1"``.
            timeout: Read timeout in seconds for one request. ``generate``
                sends a non-streaming request, so no response bytes arrive
                until the whole completion is finished. This value must
                therefore cover the full generation time (up to
                ``max_tokens``). The previous hard-coded 30 s could abort long
                generations. 1800 s is a rough default; tune per deployment.
        """
        self.api_endpoint = f"{base_url}:{port}/v1/chat/completions"
        self.timeout = timeout

    def generate(self, prompt: str, model: str = "deepseek-ai/DeepSeek-R1",
                 temperature: float = 0.6, max_tokens: int = 32768) -> Tuple[str, float]:
        """
        Send *prompt* as a single user message and return the assistant reply.

        Returns:
            ``(content, seconds)``: the reply text and the wall-clock duration
            of the HTTP request. On any error (network, timeout, HTTP status,
            malformed body) the error is logged and ``("", 0.0)`` is returned.
            Callers treat an empty string as a failed attempt.
        """
        data = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        try:
            start_time = time.monotonic()
            response = requests.post(
                self.api_endpoint,
                json=data,
                timeout=(self.CONNECT_TIMEOUT, self.timeout),
            )
            response_time = time.monotonic() - start_time

            response.raise_for_status()
            result = response.json()
            # ``or`` guards against an empty ``choices`` list and against
            # explicit JSON nulls (e.g. ``"content": null``), which a plain
            # chained ``.get`` would pass through as ``None``.
            choices = result.get("choices") or [{}]
            message = choices[0].get("message") or {}
            content = message.get("content") or ""
            return content, response_time
        except Exception as e:
            logger.error(f"API调用错误: {str(e)}")
            return "", 0.0

class ServerPool:
    """
    Thread-safe pool of inference servers.

    Each server is a ``{"base_url": ..., "port": ...}`` dict. When
    ``master_url`` is given, the list is fetched from
    ``<master_url>/workers/ips`` at construction time. A daemon thread then
    refreshes it every ``update_interval`` seconds.
    """

    def __init__(self, servers: List[Dict[str, Any]] = None, master_url: str = None,
                 worker_port: int = 30000, update_interval: float = 300):
        """
        Args:
            servers: Initial static server list (copied into the pool).
            master_url: Optional master-node URL serving the worker IP list.
            worker_port: Port assigned to every IP reported by the master.
            update_interval: Seconds between background refreshes from the master.
        """
        self.servers = list(servers or [])
        self.master_url = master_url
        self.worker_port = worker_port
        self.update_interval = update_interval
        # Guards ``self.servers``. ``threading.Lock`` is NOT reentrant, so no
        # method may call ``_update_server_list`` while holding it.
        self.lock = threading.Lock()

        # Fetch the server list from the master, then keep it fresh in the background.
        if master_url:
            self._update_server_list()
            threading.Thread(target=self._periodic_update, daemon=True).start()

    def _update_server_list(self):
        """
        Replace the server list with the IPs reported by the master.

        This is a no-op when there is no master. Network, HTTP and JSON errors
        are logged and the current list is kept. Call this WITHOUT holding
        ``self.lock``; it takes the lock itself.
        """
        if not self.master_url:
            return

        try:
            response = requests.get(f"{self.master_url}/workers/ips", timeout=10)
            response.raise_for_status()
            data = response.json()
            if "ips" in data:
                new_servers = [
                    {"base_url": f"http://{ip}", "port": self.worker_port}
                    for ip in data["ips"]
                ]
                with self.lock:
                    self.servers = new_servers
                logger.info(f"更新服务器列表成功，获取到 {len(new_servers)} 台服务器")
        except Exception as e:
            logger.error(f"更新服务器列表失败: {e}")

    def _periodic_update(self):
        """Refresh the server list every ``update_interval`` seconds, forever."""
        while True:
            time.sleep(self.update_interval)
            self._update_server_list()

    def get_server(self) -> Optional[Dict[str, Any]]:
        """
        Return a randomly chosen server, or ``None`` if none is available.

        If the pool is empty and a master is configured, the list is refreshed
        once before giving up. The refresh runs outside the lock: calling it
        while holding the non-reentrant lock would deadlock, and holding the
        lock during network I/O would also block every other worker.
        """
        with self.lock:
            if self.servers:
                return random.choice(self.servers)
            can_refresh = bool(self.master_url)

        if not can_refresh:
            return None

        self._update_server_list()

        with self.lock:
            return random.choice(self.servers) if self.servers else None

def _read_json_records(input_file: str) -> Any:
    """
    Parse *input_file* as a single JSON document, falling back to JSON Lines.

    The JSON-Lines fallback parses one JSON value per non-blank line. If
    neither format parses, the original whole-document ``JSONDecodeError``
    is raised.
    """
    # utf-8-sig also accepts files that start with a UTF-8 BOM.
    with open(input_file, 'r', encoding='utf-8-sig') as f:
        text = f.read()

    try:
        return json.loads(text)
    except json.JSONDecodeError as whole_doc_error:
        try:
            records = [json.loads(line) for line in text.splitlines() if line.strip()]
        except json.JSONDecodeError:
            raise whole_doc_error from None
        if not records:
            raise whole_doc_error
        return records


def _normalize_question(index: int, question: Any) -> Optional[Dict[str, Any]]:
    """
    Turn one raw entry into a dict with string ``question`` and ``config_index``.

    Returns ``None`` when no problem text can be derived, so the entry is skipped.
    """
    if not isinstance(question, dict):
        logger.warning(f"问题 #{index} 不是有效的字典格式，尝试转换")
        # Strings are used verbatim; any other JSON value is stringified.
        return {"question": str(question), "config_index": str(index)}

    normalized = question.copy()
    # Default the ID to the entry's position in the input file.
    normalized.setdefault("config_index", str(index))

    if "question" in normalized:
        content = normalized["question"]
        if isinstance(content, dict):
            normalized["question"] = format_question_dict(content)
        elif not isinstance(content, str):
            normalized["question"] = str(content)
        return normalized

    # No "question" field: try to assemble one from the other fields.
    text = format_question_from_fields(normalized)
    if not text:
        logger.warning(f"问题 #{index} 缺少问题描述，已跳过")
        return None
    normalized["question"] = text
    return normalized


def load_questions(input_file: str, start_from: int = 0, max_problems: int = None):
    """
    Load and normalize problems from a JSON array file or a JSON-Lines file.

    Args:
        input_file: Path to a JSON file containing a list, or a JSON-Lines file.
        start_from: Number of normalized problems to skip from the front.
        max_problems: Maximum number of problems to return (``None`` = all).

    Returns:
        List of dicts, each with a string ``question`` and a ``config_index``.
        On any read/parse error, or if the top-level JSON value is not a list,
        the error is logged and ``[]`` is returned.
    """
    logger.info(f"开始从 {input_file} 加载问题")

    try:
        questions = _read_json_records(input_file)

        if not isinstance(questions, list):
            logger.error("输入文件不是有效的问题列表")
            return []

        logger.info(f"成功从文件加载了 {len(questions)} 个原始问题")

        normalized_questions = [
            normalized
            for normalized in (_normalize_question(i, raw) for i, raw in enumerate(questions))
            if normalized is not None
        ]
        logger.info(f"成功规范化了 {len(normalized_questions)} 个问题")

        # Note: start_from indexes the normalized list (skipped entries are not counted).
        if start_from > 0:
            normalized_questions = normalized_questions[start_from:]
        if max_problems is not None:
            normalized_questions = normalized_questions[:max_problems]

        logger.info(f"最终返回 {len(normalized_questions)} 个问题")
        return normalized_questions

    except json.JSONDecodeError as e:
        logger.error(f"JSON解析错误: {e}")
        return []
    except Exception as e:
        logger.error(f"加载问题时出错: {e}")
        return []

# Field-name aliases recognised by format_question_dict, in priority order
# (for each section, the first alias present in the dict wins).
_TITLE_KEYS = ("title", "problem_title", "name", "problem_name")
_STATEMENT_KEYS = ("statement", "description", "problem_statement", "content", "body")
_INPUT_KEYS = ("input", "input_description", "input_format", "inputformat")
_OUTPUT_KEYS = ("output", "output_description", "output_format", "outputformat")
_CONSTRAINTS_KEYS = ("constraints", "constraint", "limits", "restrictions")
_EXAMPLE_KEYS = ("examples", "example", "sample", "samples", "testcases", "test_cases")
_TIME_LIMIT_KEYS = ("time_limit", "timelimit", "time")
_MEMORY_LIMIT_KEYS = ("memory_limit", "memorylimit", "memory")
# Every alias above. These keys are never re-emitted by the catch-all section.
_KNOWN_QUESTION_KEYS = frozenset(
    _TITLE_KEYS + _STATEMENT_KEYS + _INPUT_KEYS + _OUTPUT_KEYS
    + _CONSTRAINTS_KEYS + _EXAMPLE_KEYS + _TIME_LIMIT_KEYS + _MEMORY_LIMIT_KEYS
)
# Metadata keys (compared case-insensitively) that the catch-all section skips.
_METADATA_KEYS = frozenset({"id", "url", "source", "difficulty", "tags", "category", "type"})
# (field, label) pairs rendered for one example dict, in output order.
_EXAMPLE_FIELDS = (("input", "Input"), ("output", "Output"), ("explanation", "Explanation"))


def _first_present_key(mapping: dict, keys: tuple) -> Optional[str]:
    """Return the first of *keys* that is present in *mapping*, or ``None``."""
    return next((key for key in keys if key in mapping), None)


def _format_example_fields(example: dict) -> str:
    """Render the input/output/explanation lines of a single example dict."""
    return "".join(
        f"{label}:\n{example[field]}\n" for field, label in _EXAMPLE_FIELDS if field in example
    )


def _format_constraints(constraints: Any) -> str:
    """Render a constraints section: lists become bullet points, anything else is inlined."""
    if isinstance(constraints, list):
        bullets = "".join(f"- {constraint}\n" for constraint in constraints)
        return f"Constraints:\n{bullets}\n"
    return f"Constraints:\n{constraints}\n\n"


def _format_examples(examples: Any) -> str:
    """Render an examples section from a list, a single dict, or any other value."""
    if isinstance(examples, list):
        rendered = []
        for number, example in enumerate(examples, 1):
            body = _format_example_fields(example) if isinstance(example, dict) else f"{example}\n"
            rendered.append(f"Example {number}:\n{body}\n")
        return "".join(rendered)
    if isinstance(examples, dict):
        return f"Example:\n{_format_example_fields(examples)}\n"
    return f"Examples:\n{examples}\n\n"


def format_question_dict(question_dict: dict) -> str:
    """
    Render a structured problem dict as Markdown-ish plain text.

    Sections, in order: title, time/memory limits, statement, input format,
    output format, constraints, examples, then every remaining non-empty
    string field that is neither a known alias nor a metadata key. Each
    section uses the first matching alias from its key list. The result has
    surrounding whitespace stripped; ``""`` if nothing was rendered.
    """
    parts = []

    key = _first_present_key(question_dict, _TITLE_KEYS)
    if key is not None:
        parts.append(f"# {question_dict[key]}\n\n")

    limits = []
    key = _first_present_key(question_dict, _TIME_LIMIT_KEYS)
    if key is not None:
        limits.append(f"Time Limit: {question_dict[key]}")
    key = _first_present_key(question_dict, _MEMORY_LIMIT_KEYS)
    if key is not None:
        limits.append(f"Memory Limit: {question_dict[key]}")
    if limits:
        parts.append(", ".join(limits) + "\n\n")

    # Statement, input and output share one shape: optional heading + value.
    for keys, heading in (
        (_STATEMENT_KEYS, ""),
        (_INPUT_KEYS, "Input Format:\n"),
        (_OUTPUT_KEYS, "Output Format:\n"),
    ):
        key = _first_present_key(question_dict, keys)
        if key is not None:
            parts.append(f"{heading}{question_dict[key]}\n\n")

    key = _first_present_key(question_dict, _CONSTRAINTS_KEYS)
    if key is not None:
        parts.append(_format_constraints(question_dict[key]))

    key = _first_present_key(question_dict, _EXAMPLE_KEYS)
    if key is not None:
        parts.append(_format_examples(question_dict[key]))

    # Catch-all: any other non-empty string field, titled from its key name.
    for key, value in question_dict.items():
        if key in _KNOWN_QUESTION_KEYS or key.lower() in _METADATA_KEYS:
            continue
        if isinstance(value, str) and value:
            parts.append(f"{key.replace('_', ' ').title()}:\n{value}\n\n")

    return "".join(parts).strip()

def format_question_from_fields(question_obj: dict) -> str:
    """
    Build a problem statement for an entry that has no ``question`` field.

    Strategy:
      1. Join, separated by blank lines, every non-empty string found under
         the well-known statement fields (in the fixed order listed below).
      2. Failing that, walk the object's keys in order, skipping metadata
         keys (compared case-insensitively). Return the first nested dict
         that ``format_question_dict`` renders to non-empty text, or the first
         string longer than 100 characters, whichever comes first.
      3. Otherwise return ``""``.
    """
    statement_fields = (
        "title", "problem", "description", "statement", "prompt",
        "task", "challenge", "content", "text",
    )
    pieces = [
        question_obj[name]
        for name in statement_fields
        if isinstance(question_obj.get(name), str) and question_obj[name]
    ]
    if pieces:
        return "\n\n".join(pieces)

    # Keys that are unlikely to hold a problem description.
    skipped = {"config_index", "id", "selected_features", "tags", "metadata"}
    for name, candidate in question_obj.items():
        if name.lower() in skipped:
            continue
        if isinstance(candidate, dict):
            rendered = format_question_dict(candidate)
            if rendered:
                return rendered
        elif isinstance(candidate, str) and len(candidate) > 100:
            # A long free-text field is the best remaining guess.
            return candidate

    return ""

def _load_processed_ids(checkpoint_file: str) -> set:
    """Return the problem IDs recorded in *checkpoint_file*; an empty set if it is missing or unreadable."""
    if not os.path.exists(checkpoint_file):
        return set()
    try:
        with open(checkpoint_file, 'r', encoding='utf-8') as f:
            return set(json.load(f))
    except Exception as e:
        logger.error(f"加载检查点失败: {e}")
        return set()


def _save_processed_ids(checkpoint_file: str, processed_ids: set) -> None:
    """Rewrite *checkpoint_file* via a temp file plus ``os.replace``, so a crash never leaves it half-written."""
    tmp_file = f"{checkpoint_file}.tmp"
    with open(tmp_file, 'w', encoding='utf-8') as f:
        json.dump(list(processed_ids), f)
    os.replace(tmp_file, checkpoint_file)


def _extract_question_text(problem: Any) -> str:
    """Return the problem statement of *problem* as a string (dict-valued questions are formatted)."""
    if isinstance(problem, dict):
        question = problem.get("question", "")
        if isinstance(question, dict):
            return format_question_dict(question)
        return question if isinstance(question, str) else str(question)
    if isinstance(problem, str):
        return problem
    return str(problem)


def generate_test_cases(
    input_file: str,
    output_file: str,
    prompt_template: str,
    servers: List[Dict[str, Any]] = None,
    master_url: str = None,
    samples_per_problem: int = 3,
    max_retries: int = 3,
    num_workers: int = 8,
    start_from: int = 0,
    max_problems: int = None,
    checkpoint_dir: str = "./checkpoints",
    append_mode: bool = False
):
    """
    Generate ``samples_per_problem`` LLM test-case samples per problem, in parallel.

    Each successful sample is appended to *output_file* as one JSON line. A
    problem ID is added to ``<checkpoint_dir>/processed_ids.json`` only after
    ALL of its samples in this run have succeeded. An interrupted or partially
    failed problem is therefore regenerated (all of its samples) on the next
    run, rather than silently skipped with missing samples.

    Args:
        input_file: Problem file passed to ``load_questions``.
        output_file: JSON-Lines result file. If it exists and *append_mode* is
            false, it is first renamed to ``<output_file>.bak.<unix time>``.
        prompt_template: Prompt text; ``{problem}`` is replaced by the question.
        servers / master_url: Passed to ``ServerPool``.
        samples_per_problem: Independent generations per problem.
        max_retries: Attempts per sample before it is recorded as failed.
        num_workers: Maximum number of worker threads.
        start_from / max_problems: Passed to ``load_questions``.
        checkpoint_dir: Directory holding ``processed_ids.json``.
        append_mode: Keep (append to) an existing output file.

    Returns:
        ``(success_count, fail_count)``: samples written, and samples that
        failed after *max_retries* attempts or had an empty question.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_file = os.path.join(checkpoint_dir, "processed_ids.json")
    processed_ids = _load_processed_ids(checkpoint_file)

    problems = load_questions(input_file, start_from, max_problems)
    if not problems:
        logger.error("没有找到可处理的问题")
        return 0, 0

    # Back up the previous output unless we are appending to it.
    if os.path.exists(output_file) and not append_mode:
        backup_file = f"{output_file}.bak.{int(time.time())}"
        os.rename(output_file, backup_file)
        logger.info(f"原输出文件已备份为: {backup_file}")

    server_pool = ServerPool(servers=servers, master_url=master_url)

    # One task per (problem, sample). pending_samples counts, per problem ID,
    # how many samples must still succeed before the ID is checkpointed.
    tasks = []
    pending_samples: Dict[str, int] = {}
    for problem in problems:
        problem_id = str(problem.get("config_index", ""))
        if problem_id in processed_ids:
            logger.info(f"跳过已处理的问题: {problem_id}")
            continue
        for sample_id in range(samples_per_problem):
            tasks.append((problem, sample_id, problem_id))
            pending_samples[problem_id] = pending_samples.get(problem_id, 0) + 1

    logger.info(f"创建了 {len(tasks)} 个任务")

    results: List[Dict[str, Any]] = []
    task_queue = deque(tasks)
    queue_lock = threading.Lock()
    result_lock = threading.Lock()  # guards results, pending_samples, processed_ids and both files

    def record_failure(problem_id: str, sample_id: int, error: str) -> None:
        """Record a failed sample in ``results``."""
        with result_lock:
            results.append({
                "problem_id": problem_id,
                "sample_id": sample_id,
                "error": error,
                "timestamp": datetime.now().isoformat()
            })

    def persist_success(result: Dict[str, Any]) -> None:
        """Append *result* to the output file, then update ``results`` and the checkpoint."""
        problem_id = result["problem_id"]
        with result_lock:
            # Write the output line first. If this raises, nothing else has been
            # recorded yet, so the caller's retry cannot create a duplicate entry.
            with open(output_file, 'a', encoding='utf-8') as f:
                f.write(json.dumps(result, ensure_ascii=False) + '\n')
            results.append(result)
            pending_samples[problem_id] -= 1
            if pending_samples[problem_id] == 0:
                processed_ids.add(problem_id)
                try:
                    _save_processed_ids(checkpoint_file, processed_ids)
                except OSError as e:
                    # The sample itself is saved; a stale checkpoint only costs redundant work on resume.
                    logger.error(f"保存检查点失败: {e}")

    def worker():
        """Pull tasks from the shared queue until it is empty."""
        while True:
            with queue_lock:
                if not task_queue:
                    return
                problem, sample_id, problem_id = task_queue.popleft()

            # The prompt does not depend on the attempt, so build it once.
            try:
                question = _extract_question_text(problem)
            except Exception as e:
                logger.error(f"问题 {problem_id} 处理出错: {str(e)}")
                question = ""
            if not question:
                logger.error(f"问题 {problem_id} 内容为空")
                record_failure(problem_id, sample_id, "Empty question")
                continue
            prompt = prompt_template.replace("{problem}", question)

            for attempt in range(max_retries):
                server = server_pool.get_server()
                if not server:
                    logger.error("无法获取可用服务器")
                    time.sleep(5)
                    continue

                try:
                    client = SGLangClient(port=server["port"], base_url=server["base_url"])
                    test_case, response_time = client.generate(prompt)

                    if not test_case:
                        logger.warning(f"问题 {problem_id} 第 {attempt+1} 次尝试失败，未获取内容")
                        time.sleep(2)
                        continue

                    persist_success({
                        "problem_id": problem_id,
                        "sample_id": sample_id,
                        "question": question,
                        "test_case": test_case,
                        "timestamp": datetime.now().isoformat()
                    })
                    logger.info(f"问题 {problem_id} 处理成功，耗时 {response_time:.2f}s")
                    break

                except Exception as e:
                    logger.error(f"问题 {problem_id} 处理出错: {str(e)}")
                    time.sleep(2)
            else:
                # Every attempt failed (the for loop never hit ``break``).
                logger.error(f"问题 {problem_id} 在 {max_retries} 次尝试后失败")
                record_failure(problem_id, sample_id, f"Failed after {max_retries} retries")

    if tasks:
        # No point starting more threads than tasks; max_workers must be >= 1.
        worker_count = max(1, min(num_workers, len(tasks)))
        with ThreadPoolExecutor(max_workers=worker_count) as executor:
            futures = [executor.submit(worker) for _ in range(worker_count)]
            for future in futures:
                future.result()  # re-raise any unexpected worker crash

    success_count = sum(1 for r in results if "error" not in r)
    fail_count = len(results) - success_count

    return success_count, fail_count

# Fallback inference servers used when neither --master_url nor --server_list is given.
_DEFAULT_SERVERS = (
    {"base_url": "http://10.33.0.29", "port": 30000},
    {"base_url": "http://10.33.0.30", "port": 30000},
    {"base_url": "http://10.33.0.31", "port": 30000},
)


def _load_server_list(path: str) -> List[Dict[str, Any]]:
    """Read a JSON array of ``{"base_url": ..., "port": ...}`` objects from *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        servers = json.load(f)
    if not isinstance(servers, list) or not all(
        isinstance(s, dict) and "base_url" in s and "port" in s for s in servers
    ):
        raise ValueError(f"服务器列表文件格式无效: {path}")
    return servers


def main():
    """CLI entry point: parse arguments, run generation, and log summary statistics."""
    parser = argparse.ArgumentParser(description='生成代码测试用例')
    parser.add_argument('--input', type=str, required=True, help='输入问题文件路径')
    parser.add_argument('--output', type=str, required=True, help='输出测试用例文件路径')
    parser.add_argument('--prompt', type=str, required=True, help='提示模板文件路径')
    parser.add_argument('--samples', type=int, default=3, help='每个问题的样例数量')
    parser.add_argument('--workers', type=int, default=8, help='并行工作线程数量')
    parser.add_argument('--checkpoint_dir', type=str, default="./checkpoints", help='检查点文件目录')
    parser.add_argument('--append', action='store_true', help='追加模式，不清空输出文件')
    parser.add_argument('--start_from', type=int, default=0, help='从第几个问题开始处理')
    parser.add_argument('--max_problems', type=int, default=None, help='最多处理问题数量')
    parser.add_argument('--master_url', type=str, default=None, help='主服务器URL')
    parser.add_argument('--max_retries', type=int, default=3, help='每个样例的最大重试次数')
    parser.add_argument('--server_list', type=str, default=None,
                        help='静态服务器列表JSON文件路径（未提供 --master_url 时使用）')

    args = parser.parse_args()

    try:
        # The prompt template may contain non-ASCII text; do not rely on the locale encoding.
        with open(args.prompt, 'r', encoding='utf-8') as f:
            prompt_template = f.read()

        # Static servers are only used when no master is configured.
        servers = None
        if not args.master_url:
            if args.server_list:
                servers = _load_server_list(args.server_list)
            else:
                servers = [dict(server) for server in _DEFAULT_SERVERS]

        start_time = time.time()

        success, fail = generate_test_cases(
            args.input,
            args.output,
            prompt_template,
            servers=servers,
            master_url=args.master_url,
            samples_per_problem=args.samples,
            max_retries=args.max_retries,
            num_workers=args.workers,
            start_from=args.start_from,
            max_problems=args.max_problems,
            checkpoint_dir=args.checkpoint_dir,
            append_mode=args.append
        )

        total_time = time.time() - start_time
        total = success + fail
        success_rate = (success / total * 100) if total > 0 else 0

        logger.info(f"处理完成，总耗时: {total_time:.2f}s")
        logger.info(f"成功: {success}，失败: {fail}，成功率: {success_rate:.2f}%")

    except KeyboardInterrupt:
        logger.info("程序被用户中断")
        sys.exit(1)
    except Exception as e:
        # logger.exception keeps the traceback, which a plain error log would drop.
        logger.exception(f"程序执行出错: {str(e)}")
        sys.exit(1)


if __name__ == "__main__":
    main()



## 使用示例:
# 使用动态获取服务器列表
# python3 testcode_generation.py \
# --input ../data_syn/generated_problems.json \
# --output V3_problems_testcode.jsonl \
# --prompt openr1_prompt.txt \
# --samples 3 \
# --workers 8 \
# --master_url http://10.33.0.5:5000 \
# --checkpoint_dir ./test_case_checkpoints

# 使用静态配置的服务器列表（省略 --master_url 时，使用 main() 中内置的默认服务器）
# python request.py \
# --input codeforces.json \
# --output test_cases.jsonl \
# --prompt openr1_prompt.txt \
# --samples 3 \
# --workers 8 \
# --checkpoint_dir ./test_case_checkpoints