import os
import sys
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed, ThreadPoolExecutor
from datetime import datetime
from pathlib import Path

sys.path.append("../")
from utils.call_LLM import (
    LLM_Caller_for_One_Thread,
    launch_vllm_server,
    kill_vllm_server,
)
from utils.io_func import read_json, write_json
from utils.tools import batch_process_parallel
from utils.data_loader import Scenario_OSCE_Loader, Scenario_OSCE
from environment_real_only_auxiliary import Environment, args
from doctor_factory import create_doctor_agent
from modules import diagnose_from_record_and_judge_correctness, judge_topk

# Module-level mapping from system prompt to context id; each
# LLM_Caller_for_One_Thread is given its own copy of it.
global_dict_from_sys_prompt_to_context_id = dict()


##########################################################################################################################################


def run_one_scenario(_scenario_id: int, scenario: Scenario_OSCE, _repeat: int = 1):
    """Simulate one doctor consultation for a single scenario.

    Builds an ``Environment`` for the scenario and a doctor agent, then
    alternates doctor questions and environment answers until the
    environment answers ``None``.

    Args:
        _scenario_id: Index of the scenario in the dataset.
        scenario: The scenario to simulate.
        _repeat: Repetition index; stored in the result as ``repeat_id``.

    Returns:
        A tuple ``(info_dict, _scenario_id, _repeat)``. ``info_dict`` is the
        environment's info dict extended with the doctor's LLM log
        (``doctor_detailed_log``), ``top3_diagnosis`` and ``repeat_id``; it is
        ``None`` if any exception occurred. Exceptions are caught here so one
        failing scenario does not abort the whole batch, but the full
        traceback is printed for debugging.
    """
    try:
        env = Environment(scenario_id=_scenario_id, scenario=scenario)
        # Each caller gets its own copy of the shared mapping (presumably so
        # concurrent scenarios never mutate a common dict).
        LLM_caller = LLM_Caller_for_One_Thread(
            introduction_log=f"\n<hr>\n\n## {_scenario_id} Doctor\n",
            dict_from_sys_prompt_to_context_id=global_dict_from_sys_prompt_to_context_id.copy(),
        )

        doctor_agent = create_doctor_agent(
            class_name=args.doctor_class,
            LLM_caller=LLM_caller,
            model_name=args.doctor_llm,
            current_summary=env.get_initial_summary(),
        )

        # Doctor/environment interaction loop; the doctor's first turn gets
        # no answer, and a None answer from the environment ends the dialog.
        answer_to_doctor = None
        while True:
            question_from_doctor = doctor_agent.inference(answer_to_doctor)
            answer_to_doctor = env.response_to_doctor(question_from_doctor)
            if answer_to_doctor is None:
                break

        info_dict = env.info_dict()
        info_dict.update(
            {
                "doctor_detailed_log": LLM_caller.LLM_log_list,
                "top3_diagnosis": doctor_agent.get_top3_diagnosis(),
                "repeat_id": _repeat,
            }
        )
        return info_dict, _scenario_id, _repeat
    except Exception as e:
        # Best-effort: report and skip this scenario. The traceback matters
        # because the bare exception message alone rarely locates the failure
        # (especially inside worker processes).
        print("Error:", _scenario_id, _repeat, e, flush=True)
        traceback.print_exc()
        return None, _scenario_id, _repeat


def main():
    """Run all scenarios ``args.repeat_cnt`` times in parallel and save the log.

    ``args.num_scenarios`` is clamped to the dataset size (and stored back on
    ``args`` so the saved settings reflect it). Each (repeat, scenario) pair is
    run through ``run_one_scenario``. Failed runs are reported as "skipped"
    and excluded from the results.

    Writes ``detailed_log/{args.eval_id}-full-log.json`` containing
    ``{"settings": args.to_dict(), "results_list": [...]}``, with results
    sorted by ``(scenario_id, repeat_id)`` so the file is deterministic
    regardless of completion order.
    """
    scenario_loader = Scenario_OSCE_Loader(args.dataset_path)

    args.num_scenarios = min(args.num_scenarios, scenario_loader.num_scenarios)
    total_num_scenarios = args.num_scenarios

    # Baichuan doctor models run in threads, everything else in processes.
    # NOTE(review): the reason is not visible here (presumably the baichuan
    # client cannot be used across process boundaries) — confirm.
    executor_cls = (
        ThreadPoolExecutor
        if "baichuan" in args.doctor_llm.lower()
        else ProcessPoolExecutor
    )

    results_list = []
    with executor_cls(max_workers=args.parallel_thread_num) as executor:
        # Submit all tasks
        futures = [
            executor.submit(
                run_one_scenario,
                _scenario_id,
                scenario_loader.get_scenario(_scenario_id),
                _repeat,
            )
            for _repeat in range(args.repeat_cnt)
            for _scenario_id in range(total_num_scenarios)
        ]

        # Collect results as tasks complete
        for future in as_completed(futures):
            info_dict, _scenario_id, _repeat = future.result()
            if info_dict is not None:
                print(f"Scene {_scenario_id}-{_repeat}: done", flush=True)
                results_list.append(info_dict)
            else:
                print(f"Scene {_scenario_id}-{_repeat}: skipped", flush=True)

    # Completion order is nondeterministic; sorting on scenario_id alone would
    # leave repeats of the same scenario in arbitrary order.
    results_list.sort(key=lambda x: (x["scenario_id"], x.get("repeat_id", 0)))
    os.makedirs("detailed_log", exist_ok=True)
    write_json(
        {"settings": args.to_dict(), "results_list": results_list},
        f"detailed_log/{args.eval_id}-full-log.json",
    )


def _topk_hit(top_diagnoses, k):
    """Return True if any of the first ``k`` judged diagnoses is marked correct.

    Slicing (instead of indexing [0], [1], [2]) tolerates lists with fewer
    than ``k`` entries rather than raising IndexError.
    """
    return any(diag["correctness"] for diag in top_diagnoses[:k])


def _print_difficulty_stats(results_list):
    """Print accuracy, mean inquiry count and case count per difficulty level.

    Levels 1-5 are always printed (zeros when a level has no cases). Any other
    difficulty value present in the results is appended after them instead of
    raising KeyError.
    """
    # level -> [correctness_sum, num_cases, inquiry_sum]
    level_stats = {level: [0, 0, 0] for level in range(1, 6)}
    for item in results_list:
        stats = level_stats.setdefault(item["difficulty"], [0, 0, 0])
        stats[0] += item["correctness"]
        stats[1] += 1
        stats[2] += sum(item["interaction_count"])
    for level, (correct, num_cases, inquiry_cnt) in level_stats.items():
        denom = num_cases or 1  # empty level: print zeros instead of dividing by 0
        print(
            f"difficulty-level-{level}",
            f"accuracy={correct / denom * 100:.3f}",
            f"#inquiry={inquiry_cnt / denom:.2f}",
            f"#cases={num_cases}",
        )


def _print_correctness_levels(results_list, total_num_scenarios):
    """Print count and share of runs for each four-level correctness label.

    Labels, in report order: fully correct, clinically acceptable,
    partially correct, incorrect.
    """
    level_cnt = {}
    for res in results_list:
        level = res["correctness_level"]
        level_cnt[level] = level_cnt.get(level, 0) + 1
    for level in ("完全正确", "临床可接受", "部分正确", "不正确"):
        cnt = level_cnt.get(level, 0)
        print(f"{level}: {cnt} cases, {cnt / total_num_scenarios * 100:.2f}%")


def _count_auxiliary_items(dialog_list):
    """Count "- ..." bullet lines in the doctor's auxiliary-exam request turns."""
    cnt = 0
    for s in dialog_list:
        if "**Doctor:** 请求进行以下辅助检查：" not in s:
            continue
        for line in s.split("以下辅助检查：")[-1].split("\n"):
            if len(line) > 2 and line[0] == "-":
                cnt += 1
    return cnt


def eval():
    """Evaluate the saved log of run ``args.eval_id`` and write results back.

    Reads ``detailed_log/{args.eval_id}-full-log.json`` and prints:
    1. the simulated dialogs;
    2. final accuracy and mean inquiry count, per-difficulty stats and the
       four-level correctness distribution;
    3. re-diagnosis accuracy when ``args.rediagnosis_llm`` re-diagnoses from
       each recorded dialog;
    4. top-1/2/3 accuracy of the doctor's judged top-3 diagnoses;
    5. the mean number of auxiliary-exam items requested.

    Per-run ``rediagnosis`` results and the judged ``top3_diagnosis`` are
    stored on each result and the log file is overwritten.

    All accuracies divide by ``num_scenarios * repeat_cnt`` from the saved
    settings, so skipped (failed) runs count as incorrect.

    Note: this name shadows the builtin ``eval`` in this module; it is kept
    for backward compatibility.
    """
    log_path = f"detailed_log/{args.eval_id}-full-log.json"
    saved_info = read_json(log_path)
    settings = saved_info["settings"]
    results_list = saved_info["results_list"]
    total_num_scenarios = settings["num_scenarios"] * settings["repeat_cnt"]
    if total_num_scenarios == 0:
        print("No scenarios to evaluate in", log_path, flush=True)
        return
    scenario_loader = Scenario_OSCE_Loader(args.dataset_path)

    # Fixed worker count for the LLM judging passes. It is kept local so it
    # does not overwrite args and leak into the settings written back below.
    num_workers = 100

    # 1. Simulated consultations (dialog only, without chain-of-thought)
    print("\n".join(_log for res in results_list for _log in res["print_dialog"]))

    # 2. Final accuracy / inquiry count, per-difficulty and four-level stats
    total_correct_count_final = sum(res["correctness"] for res in results_list)
    # NOTE(review): only interaction_count[0] is used here, whereas the
    # per-difficulty stats use sum(interaction_count) — confirm intent.
    total_inquiry_count_auxiliary = sum(
        res["interaction_count"][0] for res in results_list
    )
    print(
        f"Final: score= {total_correct_count_final / total_num_scenarios * 100:.2f}",
        f"#inquiry= {total_inquiry_count_auxiliary / total_num_scenarios:.2f}",
    )
    _print_difficulty_stats(results_list)
    _print_correctness_levels(results_list, total_num_scenarios)

    # Both judging passes need each scenario's full record; build them once.
    full_records = [
        scenario_loader.get_scenario(res["scenario_id"]).full_record()
        for res in results_list
    ]

    # 3. Have another doctor LLM re-diagnose from the recorded dialog
    res_rediagnose_list = batch_process_parallel(
        func=diagnose_from_record_and_judge_correctness,
        args_list=[
            (
                res["dialog_no_diagnosis"],
                res["correct_diagnosis"],
                args.rediagnosis_llm,
                None,
                record,
            )
            for res, record in zip(results_list, full_records)
        ],
        num_processes=num_workers,
        use_tqdm=False,
    )
    for idx, res in enumerate(results_list):
        res["rediagnosis"] = res_rediagnose_list[idx]
    accuracy = (
        sum(x["correctness"] for x in res_rediagnose_list) / total_num_scenarios
    )
    print("Re-diagnosis using", args.rediagnosis_llm)
    print(f"Re-diagnosis Accuracy: {accuracy*100:.2f}%")

    # 4. Judge each of the doctor's top-3 diagnoses
    res_topk_list = batch_process_parallel(
        func=judge_topk,
        args_list=[
            (
                res["top3_diagnosis"],
                res["correct_diagnosis"],
                args.rediagnosis_llm,
                None,
                record,
            )
            for res, record in zip(results_list, full_records)
        ],
        num_processes=num_workers,
        use_tqdm=False,
    )
    for idx, res in enumerate(results_list):
        # Replaces the raw top-3 list with the judged entries (each has a
        # "correctness" field). NOTE(review): re-running eval() on the saved
        # log feeds these judged entries back into judge_topk — confirm it
        # accepts that shape.
        res["top3_diagnosis"] = res_topk_list[idx]
    top1_correct_count, top2_correct_count, top3_correct_count = [
        sum(_topk_hit(res["top3_diagnosis"], k) for res in results_list)
        for k in (1, 2, 3)
    ]
    print(
        f"Top-1 Accuracy: {top1_correct_count / total_num_scenarios * 100:.2f}%",
        f"\nTop-2 Accuracy: {top2_correct_count / total_num_scenarios * 100:.2f}%",
        f"\nTop-3 Accuracy: {top3_correct_count / total_num_scenarios * 100:.2f}%",
    )

    # Keep the original run's settings. num_scenarios/repeat_cnt define the
    # denominators above, and in eval-only mode the current CLI args may
    # differ. Keys missing from old logs are filled from current args, and
    # the re-diagnosis model actually used is recorded.
    write_json(
        {
            "settings": {
                **args.to_dict(),
                **settings,
                "rediagnosis_llm": args.rediagnosis_llm,
            },
            "results_list": results_list,
        },
        log_path,
    )

    # 5. Mean number of auxiliary-exam items requested per run
    cnt_aux_item_list = [
        _count_auxiliary_items(res["print_dialog"]) for res in results_list
    ]
    avg_aux_items = (
        sum(cnt_aux_item_list) / len(cnt_aux_item_list) if cnt_aux_item_list else 0.0
    )
    print(f"#auxiliary_items= {avg_aux_items:.2f}")



if __name__ == "__main__":
    main_failed = False
    if not args.eval_id:
        # Fresh run: build an id from time, dataset name and doctor model.
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        args.eval_id = f"{current_time}-{Path(args.dataset_path).stem}--doctor={args.doctor_llm.split('/')[-1]}"
        print(args.eval_id, flush=True)

        if args.doctor_llm_path:
            # NOTE(review): this server is never shut down in this script
            # even though kill_vllm_server is imported — confirm whether it
            # should be stopped once main() (or eval()) finishes.
            launch_vllm_server(
                model_path=args.doctor_llm_path, model_name=args.doctor_llm
            )
        try:
            main()
        except Exception:
            # Report instead of silently swallowing (a bare `except: pass`
            # also hid Ctrl-C). eval() is skipped because the log it reads is
            # only written at the very end of main().
            traceback.print_exc()
            main_failed = True
    else:
        print("Only eval() from", args.eval_id, flush=True)

    if args.evaluate:
        if main_failed:
            print("main() failed; skipping eval()", flush=True)
        else:
            eval()
