"""
使用方法:
python3 offline_question_gen.py --model meta-llama/Llama-3.1-8B-Instruct --json_file_path ./特征数据.json --result_file ./生成问题.json --parallel_num 8
"""

import json
import time
import argparse
import dataclasses
import os
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

import sglang as sgl
from sglang.srt.server_args import ServerArgs


def generate_problem(llm, prompt, sampling_params):
    """Run a single generation on the local model and return its text.

    Args:
        llm: Engine-like object exposing ``generate(prompt, sampling_params)``.
        prompt: Prompt string handed to the engine.
        sampling_params: Sampling options forwarded unchanged to the engine.

    Returns:
        The ``"text"`` entry of the engine output, or ``None`` when the call
        (or reading its output) raised; the error is printed in that case.
    """
    text = None
    try:
        text = llm.generate(prompt, sampling_params)["text"]
    except Exception as exc:
        # Best effort: report the problem and let the caller handle None.
        print(f"生成过程出错: {exc!s}")
    return text


def _extract_json_text(raw_text):
    """Trim a raw model reply down to the span most likely to hold a JSON object."""
    text = raw_text.strip()

    # Drop a surrounding markdown code fence, with or without a "json" tag.
    if text.startswith("```json") and text.endswith("```"):
        text = text[7:-3].strip()
    elif text.startswith("```") and text.endswith("```"):
        text = text[3:-3].strip()

    # Discard anything before the first "{" and after the last "}".
    if not text.startswith("{"):
        first_brace = text.find("{")
        if first_brace != -1:
            text = text[first_brace:]
    if not text.endswith("}"):
        last_brace = text.rfind("}")
        if last_brace != -1:
            text = text[: last_brace + 1]
    return text


def process_features(config, index, llm, sampling_params):
    """Generate one problem for a feature config and parse the model's reply.

    Args:
        config: Mapping with a ``"features"`` entry that is JSON-serializable.
        index: Identifier stored under ``"config_index"`` in the returned dict.
        llm: Engine object passed through to ``generate_problem``.
        sampling_params: Sampling options passed through to ``generate_problem``.

    Returns:
        On success, the JSON object produced by the model with
        ``"config_index"`` added. Otherwise a dict that carries
        ``"config_index"`` and an ``"error"`` message; when the reply could
        not be used as a JSON object it also carries ``"raw_result"`` and
        ``"cleaned_result_attempt"``.

    Raises:
        KeyError: If ``config`` has no ``"features"`` entry.
    """
    features_json = json.dumps(config["features"], ensure_ascii=False, indent=4)
    # json.dumps escapes newlines inside string values as the two characters
    # "\" + "n"; turn them back into real line breaks for the prompt.
    features_json = features_json.replace("\\n", "\n")

    # The prompt text is kept verbatim (including the "stype" typo) so the
    # request sent to the model is unchanged.
    prompt = f"""You are a professional competitive programming problem setter.

Please generate a **"very hard"**, Codeforces-style problem based on the following features.  
The problem must be rigorous, original, and meet all the following difficulty standards.

---

**Difficulty Requirements:**

1. **Algorithmic Complexity**: The problem must require a combination of 2–3 advanced algorithmic techniques.

2. **Thinking Depth**: The problem should rely on discovering hidden properties or structural tricks. Brute-force or direct template solutions must not work.

3. **Time Complexity Optimization**: Brute-force solutions (e.g., O(n²) or worse) should time out. Correct solutions should operate within optimized time bounds (e.g., O(n log n), O(n√n), etc.).

4. **Edge Cases and Pitfalls**: The problem must include easy-to-make mistakes or tricky details (e.g., integer overflow, off-by-one errors, dictionary order traps, etc.).

5. **Implementation Difficulty**: The problem should involve complex state transitions or multi-dimensional data structures.

---

**Scenario Requirements:**

To fit the CodeForces-style, the problem proposed should be embedded in a self-consistent scenario.  
The story should naturally motivate the algorithmic constraints and connect logically to the selected features.

---

**Output Format:**

Return your result strictly in the following JSON format:
{{
    "selected_features": [ ... ],
    "question": (codeforces-stype problem statement)
}}

Provided Features:
{features_json}
"""

    print(f"处理配置 {index}...")
    result = generate_problem(llm, prompt, sampling_params)

    # generate_problem returns None on failure; an empty reply is equally unusable.
    if not result:
        return {"config_index": index, "error": "生成问题失败"}

    cleaned_result = _extract_json_text(result)
    try:
        parsed_result = json.loads(cleaned_result)
    except json.JSONDecodeError as e:
        # Keep the raw and cleaned text so the failure can be inspected later.
        return {
            "config_index": index,
            "raw_result": result,
            "error": f"无效的JSON格式: {str(e)}",
            "cleaned_result_attempt": cleaned_result,
        }

    # Valid JSON that is not an object (list, number, string, ...) cannot carry
    # the index key; report it the same way as a decode failure instead of
    # letting the item assignment below raise TypeError.
    if not isinstance(parsed_result, dict):
        return {
            "config_index": index,
            "raw_result": result,
            "error": f"无效的JSON格式: 期望JSON对象，实际为 {type(parsed_result).__name__}",
            "cleaned_result_attempt": cleaned_result,
        }

    parsed_result["config_index"] = index
    return parsed_result


def _save_results(result_pool, result_file):
    """Write the results as JSON via a temp file so the target is never half-written."""
    tmp_path = f"{result_file}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(result_pool, f, ensure_ascii=False, indent=2)
    # os.replace swaps the finished file into place in one step.
    os.replace(tmp_path, result_file)


def main(server_args, parallel_num, json_file_path, result_file, max_retries=5):
    """Load the model and generate a problem for every unprocessed config.

    Configs whose ``"idx"`` already appears as ``"config_index"`` in
    ``result_file`` are skipped, so an interrupted run can be resumed. The
    result file is rewritten after every batch.

    Args:
        server_args: ``ServerArgs`` dataclass instance; its fields are passed
            as keyword arguments to ``sgl.Engine``.
        parallel_num: Number of configs submitted per batch, also the thread
            pool size.
        json_file_path: Path of the input JSON file, a list of configs that
            each have an ``"idx"`` entry.
        result_file: Path of the JSON file that accumulates successful results.
        max_retries: How many times a failed config is retried before it is
            given up for this run. ``None`` retries without limit. Abandoned
            configs are not written to ``result_file``, so a later run picks
            them up again.
    """
    with open(json_file_path, "r", encoding="utf-8") as f:
        all_configs = json.load(f)
    print(f"配置总数: {len(all_configs)}")

    result_pool = []
    processed_indices = set()

    # Resume from an existing result file if there is one.
    if os.path.exists(result_file):
        with open(result_file, "r", encoding="utf-8") as f:
            result_pool = json.load(f)
        processed_indices = {item["config_index"] for item in result_pool}
        print(f"从已有结果恢复: {len(result_pool)} 个已完成结果")

    task_pool = [config for config in all_configs if config["idx"] not in processed_indices]
    print(f"待处理任务数量: {len(task_pool)}")

    if not task_pool:
        print("所有任务已处理完毕。")
        return

    print("正在加载模型...")
    llm = sgl.Engine(**dataclasses.asdict(server_args))
    print("模型加载完成")

    # SGLang's native sampling parameters name the output-length limit
    # "max_new_tokens" ("max_tokens" is the OpenAI-compatible spelling).
    sampling_params = {"temperature": 0.6, "top_p": 0.95, "max_new_tokens": 32768}

    failure_counts = {}  # idx -> number of failed attempts so far
    abandoned_indices = []

    try:
        while task_pool:
            current_batch = task_pool[:parallel_num]
            successful_indices = set()
            failures = {}  # idx -> message describing the latest failure

            # NOTE(review): this assumes llm.generate may be called from
            # several worker threads at once - confirm against the SGLang
            # Engine documentation.
            with ThreadPoolExecutor(max_workers=parallel_num) as executor:
                future_to_config = {
                    executor.submit(process_features, config, config["idx"], llm, sampling_params): config
                    for config in current_batch
                }

                # Futures are consumed in submission order.
                for future in tqdm(future_to_config, desc="处理中"):
                    index = future_to_config[future]["idx"]
                    try:
                        result = future.result()
                    except Exception as e:
                        failures[index] = f"处理配置 {index} 时出错: {str(e)}"
                        continue
                    if not result or result.get("error"):
                        failures[index] = f"任务 {index} 处理失败"
                    else:
                        result_pool.append(result)
                        successful_indices.add(index)
                        print(f"成功处理配置 {index}")

            # Decide per failed config whether to retry or give up. Retries go
            # to the back of the queue so a config that keeps failing cannot
            # block the configs behind it.
            retry_later = []
            for config in current_batch:
                index = config["idx"]
                if index in successful_indices or index not in failures:
                    continue
                failure_counts[index] = failure_counts.get(index, 0) + 1
                if max_retries is not None and failure_counts[index] > max_retries:
                    abandoned_indices.append(index)
                    print(f"{failures[index]}，已超过最大重试次数 {max_retries}，放弃该任务")
                else:
                    retry_later.append(config)
                    print(f"{failures[index]}，将在后续批次重试")

            untouched = [
                config
                for config in task_pool[len(current_batch):]
                if config["idx"] not in successful_indices
            ]
            task_pool = untouched + retry_later

            _save_results(result_pool, result_file)
            print(f"已保存当前结果，共 {len(result_pool)} 条数据，剩余 {len(task_pool)} 个任务待处理")
    finally:
        # Release the engine even when the loop is interrupted or raises.
        llm.shutdown()

    if abandoned_indices:
        print(f"以下任务多次失败已放弃: {abandoned_indices}")


if __name__ == "__main__":
    # Command-line entry point: engine options are registered by ServerArgs,
    # the remaining flags control this script's batch size, input and output.
    cli = argparse.ArgumentParser(description="离线生成竞赛问题")
    ServerArgs.add_cli_args(cli)

    # (flag, type, default, help) for each script-specific option.
    script_options = (
        ("--parallel_num", int, 8, "并行处理任务数量"),
        ("--json_file_path", str, "./features.json", "输入JSON文件路径"),
        ("--result_file", str, "generated_problems.json", "结果输出文件路径"),
    )
    for flag, value_type, default_value, help_text in script_options:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    parsed = cli.parse_args()
    main(
        server_args=ServerArgs.from_cli_args(parsed),
        parallel_num=parsed.parallel_num,
        json_file_path=parsed.json_file_path,
        result_file=parsed.result_file,
    )