from utils import load_data_json_lines, get_truth_result_and_code, calculat_result_match_rate, get_truth_result_from_data, data_to_file, get_code_from_response, calculate_result_pass_rate, get_result_from_response, calculate_result_code_success_rate
from client import Client
from prompt import data_to_prompt
from action import Pytest_python, Gtest_cpp
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List
import time
import os


def parse_arguments():
    """Build the evaluation command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding the evaluation settings: the
        debug/lite/think switches, LLM server and model settings,
        concurrency, log directory, subtasks, sampling parameters, number
        of codes per case, and the requested output type.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Evaluate LLM in Generating Single Test Capabilities")

    # (flag, add_argument keyword options), registered in this order so the
    # --help listing stays stable.
    option_specs = [
        ("--debug", {"action": "store_true", "default": False,
                     "help": "Run in debug mode with limited samples"}),
        ("--lite", {"action": "store_true", "default": False,
                    "help": "Whether to run the lite version"}),
        ("--think", {"action": "store_true", "default": False,
                     "help": "think type of LLM "}),
        ("--server_url", {"type": str, "default": "http://localhost:8013",
                          "help": "Server URL for the LLM API"}),
        ("--model_id", {"type": str, "default": "test_eb_model",
                        "help": "Model ID to use for evaluation"}),
        ("--api_key", {"type": str, "default": None,
                       "help": "Model API key to use for evaluation"}),
        ("--concurrency", {"type": int, "default": 1,
                           "help": "Number of concurrent requests"}),
        ("--log_dir", {"type": str, "default": "output",
                       "help": "Directory for log files"}),
        ("--subtasks", {"nargs": "+", "default": ["python"],
                        "help": "Specific subtasks to run. Defaults to python. Optional: python and cpp."}),
        ("--top_p", {"type": float, "default": 0.0,
                     "help": "Top-p sampling parameter for the model"}),
        ("--temperature", {"type": float, "default": 1.0,
                           "help": "Temperature parameter for the model"}),
        ("--max_tokens", {"type": int, "default": 2048,
                          "help": "Max tokens parameter for the model"}),
        ("--codes", {"type": int, "default": 10,
                     "help": "Number of codes per case. Optional: 1, 2, 5 and 10."}),
        ("--output_type", {"type": str, "default": "code",
                           "help": "Output type of LLM. Optional: code, code_part_problem, test_code, case and text"}),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def process_prompts_concurrently(
    prompt_list: List[str],
    client,
    concurrency: int = 5,
    use_think: bool = False
) -> List[str]:
    """Send every prompt to ``client`` on a thread pool and collect the responses.

    Parameters:
    - prompt_list: prompts to send.
    - client: object exposing
      ``interact(prompt=..., use_system_prompt=..., use_thinking=...)``.
    - concurrency: maximum number of worker threads.
    - use_think: forwarded to ``client.interact`` as ``use_thinking``.

    Returns:
    - The responses, in the same order as ``prompt_list``. Each element is
      whatever ``client.interact`` returned, passed through unchanged.

    Raises:
    - Any exception raised by ``client.interact`` is re-raised here. The
      thread pool still waits for in-flight tasks before the exception
      propagates.
    """
    import threading

    total_tasks = len(prompt_list)
    completed_tasks = 0
    # `completed_tasks += 1` is a read-modify-write, which is not atomic across
    # threads; the lock keeps the progress counter from skipping/repeating.
    progress_lock = threading.Lock()

    def process_single_prompt(prompt: str) -> str:
        """Send one prompt and report progress."""
        nonlocal completed_tasks
        response = client.interact(prompt=prompt, use_system_prompt=True, use_thinking=use_think)

        # Increment the completed-task count and print progress.
        with progress_lock:
            completed_tasks += 1
            done = completed_tasks
        print(f"已完成交互任务: {done}/{total_tasks}")

        return response

    print(f"开始处理 {total_tasks} 个交互任务，并发数: {concurrency}")
    with ThreadPoolExecutor(max_workers=concurrency) as executor:
        # One future per prompt, kept in input order.
        futures = [executor.submit(process_single_prompt, prompt) for prompt in prompt_list]

        # Wait in completion order so the first failure surfaces promptly.
        for future in as_completed(futures):
            future.result()

    # Collect results in the original prompt order.
    return [future.result() for future in futures]

# Output types whose responses are solution code scored by pass/fail rather
# than compared against ground-truth results.
_CODE_OUTPUT_TYPES = ("code", "code_part_problem")
# Every output type the execution step below knows how to handle.
_SUPPORTED_OUTPUT_TYPES = _CODE_OUTPUT_TYPES + ("test_code", "case", "text")
# --codes values that map to a ./data/<lang>_test400_<n>case.json file name.
_SUPPORTED_CODES = (1, 2, 5, 10)


def _load_data_and_prompts(args):
    """Load the dataset for the selected subtask and build its prompts.

    "python" takes precedence when both "python" and "cpp" are in
    ``args.subtasks``.

    Returns:
        (data, prompt_list, language), where ``language`` is "python" or "cpp".

    Raises:
        ValueError: if no supported subtask is selected, or if ``args.codes``
            has no matching dataset file. The ``--codes`` check is skipped in
            debug mode.
    """
    if "python" in args.subtasks:
        language, file_prefix, label = "python", "python", "Python"
    elif "cpp" in args.subtasks:
        language, file_prefix, label = "cpp", "c++", "C++"
    else:
        raise ValueError(
            f"No supported subtask in {args.subtasks!r}; choose 'python' and/or 'cpp'."
        )

    print(f"开始加载{label}数据集")
    if args.debug:
        data_path = f"./data/{file_prefix}_test2_5case.json"
    elif args.codes in _SUPPORTED_CODES:
        data_path = f"./data/{file_prefix}_test400_{args.codes}case.json"
    else:
        raise ValueError(
            f"Unsupported --codes value {args.codes}; expected one of {_SUPPORTED_CODES}."
        )
    data = load_data_json_lines(data_path)

    # Build the prompt for each data item.
    print(f"开始生成{label}数据的prompt")
    prompt_list = data_to_prompt(data, language, args.output_type)
    return data, prompt_list, language


def _execute_responses(language, output_type, data, response_list):
    """Turn model responses into predicted results for one subtask.

    Returns:
        (action_log_list, predict_result, cov_dict). ``cov_dict`` is None for
        "text" and the code output types. ``action_log_list`` is None for "text".

    Raises:
        ValueError: for an output type this function does not handle.
    """
    if language == "python":
        pytest_class = Pytest_python()
    else:
        gtest_class = Gtest_cpp()

    if output_type == "text":
        # The prediction is read directly from <result>...</result> tags.
        predict_result_str = get_code_from_response(response_list, start_str='<result>', end_str='</result>')
        return None, get_result_from_response(predict_result_str), None

    if language == "python":
        if output_type == "test_code":
            action_log_list, predict_result, cov_dict = pytest_class.pytest_action_pair(data, response_list, print_error=True)
            return action_log_list, predict_result, cov_dict
        if output_type == "case":
            action_log_list, predict_result, cov_dict = pytest_class.pytest_case_action(data, response_list, print_error=True, is_response=True)
            return action_log_list, predict_result, cov_dict
        if output_type in _CODE_OUTPUT_TYPES:
            action_log_list, predict_result = pytest_class.pytest_action_pass(data, response_list, print_error=True)
            return action_log_list, predict_result, None
    else:
        if output_type == "test_code":
            action_log_list, predict_result, cov_dict = gtest_class.gtest_action_pair(data, response_list, print_error=True)
            return action_log_list, predict_result, cov_dict
        if output_type == "case":
            action_log_list, predict_result, cov_dict = gtest_class.gtest_case_action(data, response_list, print_error=True, is_response=True)
            return action_log_list, predict_result, cov_dict
        if output_type in _CODE_OUTPUT_TYPES:
            action_log_list, predict_result = gtest_class.gtest_action_pass(data, response_list, print_error=True)
            return action_log_list, predict_result, None

    raise ValueError(f"Unsupported --output_type {output_type!r}.")


def main():
    """Run the evaluation end to end.

    Steps: parse arguments, load data and prompts, query the LLM
    concurrently, execute or parse the responses, score them, save the
    results via ``data_to_file``, and print the elapsed time.

    Raises:
        ValueError: for an unsupported --output_type, subtask, or --codes
            value. These are checked before any LLM request is sent.
    """
    start_time = time.time()
    args = parse_arguments()

    # Fail fast, before the (expensive) LLM phase, on an output type the
    # execution step cannot handle.
    if args.output_type not in _SUPPORTED_OUTPUT_TYPES:
        raise ValueError(
            f"Unsupported --output_type {args.output_type!r}; "
            f"expected one of {_SUPPORTED_OUTPUT_TYPES}."
        )
    is_code_output = args.output_type in _CODE_OUTPUT_TYPES

    # Load the dataset and build prompts.
    data, prompt_list, language = _load_data_and_prompts(args)

    # Create the client.
    print("开始创建客户端实例")
    client = Client(
        model_id=args.model_id,
        server_url=args.server_url,
        top_p=args.top_p,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        api_key=args.api_key
    )
    print("客户端实例创建完成")

    # Send the requests and collect the responses.
    print("开始并发处理LLM交互请求")
    response_all_list = process_prompts_concurrently(
        prompt_list=prompt_list,
        client=client,
        concurrency=args.concurrency,
        use_think=args.think
    )

    # In think mode each response is indexed as a sequence; element 0 is the
    # answer that gets executed. NOTE(review): element 1 presumably holds the
    # reasoning text; confirm against Client.interact.
    if args.think:
        response_list = [sublist[0] for sublist in response_all_list]
    else:
        response_list = response_all_list

    # Execute the responses.
    print("开始执行操作获取结果")
    truth_result = get_truth_result_from_data(data, use_list=True)
    action_log_list, predict_result, cov_dict = _execute_responses(
        language, args.output_type, data, response_list
    )

    if is_code_output:
        print("以下是通过情况:")
        print(predict_result)
    else:
        print("以下是真实结果：")
        print(truth_result)
        print("以下是预测结果：")
        print(predict_result)

    # Score the results.
    print("开始计算结果匹配率")
    if is_code_output:
        result_data = calculate_result_code_success_rate(predict_result)
    else:
        result_data = calculate_result_pass_rate(predict_result, truth_result, cov_dict)

    # Save the results. The full responses (response_all_list) are saved,
    # not just the executed part.
    print("开始保存结果到json文件")
    data_to_file(data=data, prompt_list=prompt_list, response_list=response_all_list, action_log_list=action_log_list, output_dir=args.log_dir, result_data=result_data, predict_result=predict_result, non_test_code=is_code_output)
    end_time = time.time()
    print(f"总耗时: {end_time - start_time:.2f}秒")

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()