"""
使用方法:
python3 offline_answer_gen.py --model meta-llama/Llama-3.1-8B-Instruct --json_file_path ./问题数据.json --result_file ./生成结果.jsonl --batch_size 16
"""

import argparse
import dataclasses
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

import sglang as sgl
from sglang.srt.server_args import ServerArgs


def generate_answers(llm, prompts, sampling_params):
    """Generate one answer per prompt with the local model.

    Args:
        llm: Object exposing ``generate(prompts, sampling_params)`` that
            returns an iterable of mappings, each with a ``"text"`` entry.
        prompts: Prompts forwarded unchanged to ``llm.generate``.
        sampling_params: Sampling settings forwarded unchanged.

    Returns:
        A list holding the ``"text"`` value of every output, in the order
        the model returned them.
    """
    answers = []
    for generation in llm.generate(prompts, sampling_params):
        answers.append(generation["text"])
    return answers


def process_questions_batch(sub_batch, llm, sampling_params):
    """Build prompts for a batch of tasks and collect the model's answers.

    Args:
        sub_batch: Sequence of task dicts. Each one must provide a
            ``"question"`` entry (any JSON-serializable value) and a
            ``"config_index"`` entry identifying the task.
        llm: Model handle passed through to ``generate_answers``.
        sampling_params: Sampling settings passed through to
            ``generate_answers``.

    Returns:
        A ``(results, failed)`` tuple. ``results`` is a list of dicts with
        the keys ``"index"``, ``"question"`` and ``"answer"``; ``failed`` is
        the list of tasks from ``sub_batch`` that produced no answer. When
        generation raises, every task in the batch is reported as failed.

    Raises:
        KeyError: If a task lacks ``"question"`` or ``"config_index"``.
    """
    # Nothing to do; avoid sending an empty request to the model.
    if not sub_batch:
        return [], []

    indices = []
    questions = []
    prompts = []

    for sample in sub_batch:
        # Serialize the question once; the same text is used for the prompt
        # and stored in the result record.
        question_text = json.dumps(sample["question"], ensure_ascii=False, indent=4)
        # Turn JSON-escaped "\n" sequences into real line breaks.
        question_text = question_text.replace("\\n", "\n")

        prompt = f"""Now that you are a code expert, I have provided you with the QUESTION. Complete the problem with awesome code logic and give a richly commented analysis in the code of your answer. Include the necessary packages. Give out code implementation. Enclose the python code with ```python and ```.
- QUESTION -
{question_text}
"""
        indices.append(sample["config_index"])
        questions.append(question_text)
        prompts.append(prompt)

    # Only the generation call is guarded: a failure here fails the whole batch.
    try:
        responses = list(generate_answers(llm, prompts, sampling_params))
    except Exception as e:
        print(f"批处理失败: {str(e)}")
        print("批处理尝试失败，将任务标记为失败")
        return [], sub_batch

    # If the model returned fewer answers than prompts, treat the missing
    # ones as failures instead of discarding the answers that did arrive.
    responses.extend([None] * (len(sub_batch) - len(responses)))

    results = []
    failed_samples = []
    for sample, index, question_text, response in zip(sub_batch, indices, questions, responses):
        if response is None:
            print(f"任务 {index} 处理失败")
            failed_samples.append(sample)
            continue

        results.append({
            "index": index,
            "question": question_text,
            "answer": response,
        })
        print(f"成功处理配置 {index}")

    return results, failed_samples


def save_results_jsonl(new_results, result_file):
    """Append result records to a JSONL file, one JSON document per line.

    Args:
        new_results: Records to append; each is serialized with
            ``json.dumps(..., ensure_ascii=False)``.
        result_file: Destination path. If it does not end in ``.jsonl`` the
            extension is replaced with ``.jsonl``.

    Any exception raised while writing is caught and reported on stdout;
    nothing is returned.
    """
    # Always write to a path carrying the .jsonl extension.
    if result_file.endswith('.jsonl'):
        target_path = result_file
    else:
        stem, _ = os.path.splitext(result_file)
        target_path = stem + '.jsonl'

    try:
        # Append mode keeps records written by earlier calls.
        with open(target_path, "a", encoding="utf-8") as sink:
            sink.writelines(
                json.dumps(record, ensure_ascii=False) + "\n"
                for record in new_results
            )
        print(f"已成功追加 {len(new_results)} 条结果到 {target_path}")
    except Exception as err:
        print(f"追加结果时出错: {str(err)}")

def main(server_args, json_file_path, result_file, batch_size=16, num_workers=8):
    """Generate answers for all pending tasks and append them to a JSONL file.

    Tasks whose ``"config_index"`` already appears as ``"index"`` in the
    result file are skipped, so an interrupted run can be resumed.

    Args:
        server_args: Dataclass instance whose fields are passed as keyword
            arguments to ``sgl.Engine``.
        json_file_path: Path to a JSON file containing the list of tasks;
            every task must have a ``"config_index"`` entry.
        result_file: Output path. If it does not end in ``.jsonl`` the
            extension is replaced with ``.jsonl``.
        batch_size: Number of tasks submitted per batch.
        num_workers: Number of worker threads submitting batches.

    Raises:
        ValueError: If ``batch_size`` or ``num_workers`` is smaller than 1.
    """
    # Reject bad sizes up front rather than after the model has been loaded.
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    if num_workers < 1:
        raise ValueError(f"num_workers must be at least 1, got {num_workers}")

    # Results are always stored under a .jsonl extension.
    jsonl_file = result_file
    if not jsonl_file.endswith('.jsonl'):
        jsonl_file = os.path.splitext(result_file)[0] + '.jsonl'

    # Load the task list.
    with open(json_file_path, "r", encoding="utf-8") as f:
        tasks = json.load(f)
    print(f"任务总数: {len(tasks)}")

    result_pool = []
    processed_indices = set()

    # Recover finished results from a previous run, if any.
    if os.path.exists(jsonl_file):
        try:
            with open(jsonl_file, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # Blank lines carry no record.
                        continue
                    try:
                        item = json.loads(line)
                        processed_indices.add(item["index"])
                    except (json.JSONDecodeError, KeyError, TypeError):
                        # A malformed record must not stop recovery of the
                        # remaining lines, otherwise finished tasks would be
                        # generated again and duplicated in the output.
                        print("警告：跳过无效的JSONL行")
                        continue
                    result_pool.append(item)
            print(f"从JSONL文件恢复: {len(result_pool)} 个已完成结果")
        except Exception as e:
            print(f"读取JSONL文件时出错: {str(e)}")

    # Keep only the tasks that have no stored result yet.
    task_pool = [config for config in tasks if config["config_index"] not in processed_indices]
    print(f"待处理任务数量: {len(task_pool)}")

    if not task_pool:
        print("所有任务已处理完毕。")
        return

    # Start the inference engine.
    print("正在加载模型...")
    llm = sgl.Engine(**dataclasses.asdict(server_args))
    print("模型加载完成")

    # SGLang's native sampling parameters name the output-length limit
    # "max_new_tokens"; "max_tokens" belongs to the OpenAI-compatible API.
    sampling_params = {"temperature": 0.6, "top_p": 0.95, "max_new_tokens": 32768}

    total_tasks = len(task_pool)
    completed = 0
    failed = 0

    try:
        # NOTE(review): all worker threads share one engine; confirm that
        # Engine.generate of the installed SGLang version may be called
        # concurrently from multiple threads.
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            # Remember each batch next to its future so a crashed batch can
            # still be counted as failed.
            submitted = []
            for start in range(0, total_tasks, batch_size):
                sub_batch = task_pool[start:start + batch_size]
                future = executor.submit(process_questions_batch, sub_batch, llm, sampling_params)
                submitted.append((future, sub_batch))

            # Collect results in submission order with a progress bar.
            for future, sub_batch in tqdm(submitted, desc="处理批次", total=len(submitted)):
                try:
                    results, failed_configs = future.result()
                except Exception as e:
                    print(f"处理批次时出错: {str(e)}")
                    failed += len(sub_batch)
                    continue

                # Persist successful answers right away so they survive a
                # later interruption.
                if results:
                    result_pool.extend(results)
                    save_results_jsonl(results, jsonl_file)
                    completed += len(results)

                failed += len(failed_configs)

                print(f"进度: {completed}/{total_tasks} "
                      f"({completed/total_tasks*100:.1f}%), 失败: {failed}")
    finally:
        # Release the engine even when processing is interrupted.
        llm.shutdown()

    print(f"任务完成。成功处理: {completed}, 失败: {failed}")


if __name__ == "__main__":
    cli = argparse.ArgumentParser(description='离线生成代码答案')
    # Register the model/server options defined by ServerArgs first.
    ServerArgs.add_cli_args(cli)

    # Script-specific options as (flag, type, default, help text).
    script_options = (
        ('--json_file_path', str, "./questions.json", '输入JSON文件路径'),
        ('--result_file', str, "generated_answers.jsonl", '结果输出文件路径'),
        ('--batch_size', int, 16, '每个批次的任务数量'),
        ('--num_workers', int, 8, '并行处理的工作线程数'),
    )
    for flag, value_type, default_value, help_text in script_options:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    parsed = cli.parse_args()
    # Build the engine configuration from the parsed command line.
    engine_config = ServerArgs.from_cli_args(parsed)

    main(
        server_args=engine_config,
        json_file_path=parsed.json_file_path,
        result_file=parsed.result_file,
        batch_size=parsed.batch_size,
        num_workers=parsed.num_workers,
    )