import inspect
import os
import re
import sys
import traceback
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path

sys.path.append("../")
from utils.call_LLM import (
    LLM_Caller_for_One_Thread,
    launch_vllm_server,
    kill_vllm_server,
)
from utils.io_func import read_json, write_json
from utils.tools import batch_process_parallel
from logger import Logger
from utils.data_loader import Scenario_OSCE_Loader, Scenario_OSCE
from environment_real_only_auxiliary import args
from modules import judge_topk


def summarize_record(
    scenario: Scenario_OSCE, model_name: str, LLM_caller: LLM_Caller_for_One_Thread
):
    """
    Return the patient-record text that the doctor agent diagnoses from.

    No summarization is performed: the scenario's full record is returned
    unchanged. ``model_name`` and ``LLM_caller`` are accepted but not used
    (presumably placeholders for an LLM-based summary step — confirm).
    """
    full_text = scenario.full_record()
    return full_text


# Prompt for Doctor_API: asks a general-purpose LLM for a ranked differential
# diagnosis. It contains exactly one "{}" placeholder, filled via str.format()
# with the patient record. The answer is expected inside <answer>...</answer>,
# one diagnosis per line.
prompt_diagnosis_from_record = """你是一名医学专家，你的任务是根据提供的患者病历，推理出top-5鉴别诊断列表。

**核心要求：**
- **诊断排序**：诊断列表按可能性由高到低排列。
- **诊断完整性**：每个诊断都应是完整的，应具体明确（例如：使用“右下叶肺炎”而非“肺炎”；“冠心病不稳定型心绞痛”而非“心脏病”）；可包含主要疾病和相关的并发症/合并症（例如：2型糖尿病 合并 社区获得性肺炎）。
- **诊断竞争性**：列表中的各项诊断应该是相互竞争的备选方案（即鉴别诊断）。**不要将一个统一病理过程的不同方面拆分成独立的条目**（如将“社区获得性肺炎”和“发热”分别列为两个诊断）。
- **聚焦诊断**：你的回答应专注于诊断推理过程和最终的诊断列表。**严禁**提供任何治疗方案、用药建议或健康指导，也不要包含病人的检查结果等信息。
- **诊断个数**：允许鉴别诊断个数不足5个。


以下是病人的信息：
<病历>
{}
</病历>


输出格式：
逐步的分析...
<answer>
诊断1
诊断2
...
</answer>

输出示例：
...（逐步的分析）
<answer>
结核性脑膜炎/脑膜脑炎，伴有社区获得性肺炎
鼻窦旁脓肿，并发结核性全身感染
</answer>


现在请先给出逐步的分析，然后输出若干个相互竞争、完整的诊断方案，不要给出其他无关内容。
"""


class BaseDoctorAgent(ABC):
    """
    Common base class for doctor agents.

    Stores the model name and the LLM caller, and holds the most recent
    top-3 diagnosis list (three ``None`` placeholders until a subclass fills
    it in). Subclasses override ``get_top3_diagnosis`` to query a model.
    """

    def __init__(
        self,
        LLM_caller: LLM_Caller_for_One_Thread,
        model_name="gpt4",
    ) -> None:
        self.LLM_caller = LLM_caller
        self.model_name = model_name
        # Placeholder result: three empty slots.
        self.top3_diagnosis = [None for _ in range(3)]

    def get_top3_diagnosis(self, record_summary: str) -> list:
        """Return the stored top-3 list; the base class ignores ``record_summary``."""
        cached = self.top3_diagnosis
        return cached


# Leading enumeration marker such as "1. ", "2、", "3)" or "4）". The negative
# lookahead keeps a decimal at the start of a line (e.g. "3.5cm ...") intact.
_ENUM_PREFIX_RE = re.compile(r"^\d+\s*[.、．)）](?!\d)\s*")


class Doctor_API(BaseDoctorAgent):
    """
    Doctor agent backed by a prompt-driven LLM (``prompt_diagnosis_from_record``).

    The model is asked for a ranked differential diagnosis inside an
    ``<answer>`` block; the first three non-empty lines become the top-3 list.
    """

    def __init__(
        self,
        LLM_caller: LLM_Caller_for_One_Thread,
        model_name="gpt4",
    ) -> None:
        super().__init__(LLM_caller, model_name)

    @staticmethod
    def _parse_diagnosis_lines(text: str, k: int = 3) -> list:
        """
        Split the model's answer into at most ``k`` diagnosis strings.

        Each non-empty line is one diagnosis; a leading "-" bullet or a
        leading enumeration marker ("1.", "2、", "3)") is removed. Lines that
        are empty after cleanup are skipped.
        """
        diagnoses = []
        for raw_line in text.splitlines():
            line = raw_line.strip()
            if line.startswith("-"):
                line = line[1:].strip()
            else:
                # Only strip numbering directly at the line start, so entries
                # such as "2型糖尿病 ... HbA1c 8.5%" are kept whole.
                line = _ENUM_PREFIX_RE.sub("", line, count=1)
            if line:
                diagnoses.append(line)
            if len(diagnoses) >= k:
                break
        return diagnoses

    def get_top3_diagnosis(self, record_summary: str) -> list:
        """
        Query the model with the record and return exactly 3 entries.

        Missing diagnoses are padded with ``None``. The result is also stored
        in ``self.top3_diagnosis``.
        """
        # Presumably returns the text inside <answer>...</answer> — confirm
        # against LLM_Caller_for_One_Thread.query_model_and_extract_label.
        diag_list_str = self.LLM_caller.query_model_and_extract_label(
            model_str=self.model_name,
            prompt=prompt_diagnosis_from_record.format(record_summary),
            system_prompt=None,
            role="Doctor",
            ensure_label="answer",
            try_cnt=10,
        )
        top3 = self._parse_diagnosis_lines(diag_list_str, k=3)
        self.top3_diagnosis = (top3 + [None] * 3)[:3]
        return self.top3_diagnosis


class Doctor_SFT(BaseDoctorAgent):
    """
    Doctor agent for this project's own fine-tuned LLM only; it relies on the
    fine-tuned model's ``<reason>...</reason>`` output format and must not be
    used with other prompt-based LLMs.

    Pipeline: one "long" differential diagnosis call, then optional per-item
    verification against the record (``step_by_step_verify``). Verified
    diagnoses are ranked first, followed by the rest in their original order.
    """

    _LONG_DIAGNOSIS_SYSTEM_PROMPT = (
        "你是一名专业医生，负责根据提供的病人信息，进行鉴别诊断。"
    )
    _DIAGNOSIS_USER_PROMPT = "<病历>\n{}\n</病历>"
    _VERIFY_SYSTEM_PROMPT = (
        "你负责验证指定的诊断是否与病人的信息矛盾。若无矛盾，输出yes"
    )
    _VERIFY_USER_PROMPT = "<病历>\n{}\n</病历>\n\n<诊断>\n{}\n</诊断>"

    def __init__(
        self,
        LLM_caller: LLM_Caller_for_One_Thread,
        model_name="gpt4",
        max_auxiliary_num=8,
        step_by_step_verify=False,
    ) -> None:
        super().__init__(LLM_caller, model_name)
        # Not read anywhere in this file.
        self.max_auxiliary_num = max_auxiliary_num
        self.step_by_step_verify = step_by_step_verify
        # Conversation log (not read anywhere in this file).
        self.conversation_history = []
        self.enable_long_diagnosis = False
        self.generation_kwargs = {
            "max_tokens": 7168,
            "temperature": 0.7,
        }

    def _long_diagnosis(self, record_summary: str) -> str:
        """
        Ask the model for a differential diagnosis; up to 3 attempts.

        Returns the text after the last ``</reason>`` tag. An attempt is
        accepted when that text is non-empty and shorter than 800 characters.

        Raises:
            ValueError: if no attempt produced an acceptable answer.
        """
        messages = [
            {"role": "system", "content": self._LONG_DIAGNOSIS_SYSTEM_PROMPT},
            {
                "role": "user",
                "content": self._DIAGNOSIS_USER_PROMPT.format(record_summary),
            },
        ]
        for _ in range(3):
            raw_answer = self.LLM_caller.query_model(
                model_str=self.model_name,
                messages=messages,
                role="long diagnosis",
                ensure_label=None,
                generation_kwargs=self.generation_kwargs,
            )
            diagnosis_str = raw_answer.split("</reason>")[-1].strip()
            # Overlong output usually means the model did not stop at a list.
            if diagnosis_str and len(diagnosis_str) < 800:
                return diagnosis_str
        raise ValueError("辅检时鉴别诊断失败")

    def _verify_diagnosis(self, record_summary: str, diagnosis: str):
        """
        Check one diagnosis against the record; up to 3 attempts.

        Returns True if the model's verdict starts with "yes", False for any
        other non-empty verdict, and None if every attempt returned an empty
        verdict. When ``step_by_step_verify`` is off, every diagnosis is
        accepted without querying the model.
        """
        if not self.step_by_step_verify:
            return True
        messages = [
            {"role": "system", "content": self._VERIFY_SYSTEM_PROMPT},
            {
                "role": "user",
                "content": self._VERIFY_USER_PROMPT.format(record_summary, diagnosis),
            },
        ]
        for _ in range(3):
            raw_answer = self.LLM_caller.query_model(
                model_str=self.model_name,
                messages=messages,
                role="verify diagnosis",
                ensure_label=None,
                generation_kwargs=self.generation_kwargs,
            )
            verify_result = raw_answer.split("</reason>")[-1].strip()
            if verify_result:
                return verify_result.startswith("yes")
        return None

    def get_top3_diagnosis(self, record_summary: str) -> list:
        """
        Diagnose only; return exactly 3 entries (padded with ``None``).

        Verified diagnoses come first, then the ones that were contradicted or
        could not be verified, each group in the model's original order. The
        result is also stored in ``self.top3_diagnosis``.

        Raises:
            ValueError: propagated from the long-diagnosis step.
        """
        diagnosis_str = self._long_diagnosis(record_summary)
        diagnosis_list = [d.strip() for d in diagnosis_str.split("\n") if d.strip()]

        verified_diagnosis_list = []
        unverified_diagnosis_list = []
        for d in diagnosis_list:
            if self._verify_diagnosis(record_summary, d):
                verified_diagnosis_list.append(d)
            else:
                # Contradicted, or no usable verdict after all retries: keep it
                # as a candidate instead of dropping it, but rank it last.
                unverified_diagnosis_list.append(d)

        # When nothing verified, this is simply diagnosis_list in model order.
        ranked = verified_diagnosis_list + unverified_diagnosis_list
        self.top3_diagnosis = (ranked + [None] * 3)[:3]
        return self.top3_diagnosis


def create_doctor_agent(
    class_name: str,
    LLM_caller: LLM_Caller_for_One_Thread,
    model_name: str = "gpt4",
    **agent_kwargs,
) -> "BaseDoctorAgent":
    """
    Look up a doctor-agent class by name in this module and instantiate it.

    Args:
        class_name (str): name of a BaseDoctorAgent subclass defined in this
            module, e.g. "Doctor_API" or "Doctor_SFT".
        LLM_caller (LLM_Caller_for_One_Thread): forwarded to the constructor.
        model_name (str): forwarded to the constructor.
        **agent_kwargs: extra keyword arguments forwarded to the constructor,
            e.g. ``step_by_step_verify=True`` for Doctor_SFT. They must be
            accepted by the target class.

    Returns:
        BaseDoctorAgent: an instance of the requested class.

    Raises:
        ValueError: if no class with that name exists in this module's globals.
        TypeError: if the name resolves to a class that is not a proper
            subclass of BaseDoctorAgent (the base class itself is rejected).
    """
    agent_class = globals().get(class_name)

    # 1. The name must exist and refer to a class.
    if not agent_class or not inspect.isclass(agent_class):
        raise ValueError(f"错误：未在当前作用域中找到名为 '{class_name}' 的类。")

    # 2. It must be a concrete agent (a subclass, not the base itself), so the
    #    returned object implements get_top3_diagnosis.
    if not issubclass(agent_class, BaseDoctorAgent) or agent_class is BaseDoctorAgent:
        raise TypeError(f"错误：类 '{class_name}' 不是 BaseDoctorAgent 的有效子类。")

    return agent_class(
        LLM_caller=LLM_caller,
        model_name=model_name,
        **agent_kwargs,
    )


def run_one_scenario(_scenario_id: int, scenario: Scenario_OSCE, _repeat: int = 1):
    """
    Evaluate a single patient record (static diagnosis from the full record).

    Intended to be called concurrently: each call builds its own LLM caller
    and doctor agent from the module-level ``args``.

    Returns:
        tuple: ``(info_dict, _scenario_id, _repeat)``. ``info_dict`` is None if
        any step raised; the error and its traceback are printed.
    """
    try:
        # Per-call LLM caller, so logs from parallel scenarios don't mix.
        LLM_caller = LLM_Caller_for_One_Thread(
            introduction_log=f"\n<hr>\n\n## {_scenario_id} Doctor\n"
        )

        doctor_agent = create_doctor_agent(
            args.doctor_class, LLM_caller=LLM_caller, model_name=args.doctor_llm
        )

        # Static diagnosis from the full record summary.
        record_summary = summarize_record(
            scenario=scenario, model_name=args.measurement_llm, LLM_caller=LLM_caller
        )
        static_top3_diagnosis_list = doctor_agent.get_top3_diagnosis(record_summary)

        correct_diagnosis = scenario.diagnosis_information()
        correctness_list = judge_topk(
            topk_diagnosis=static_top3_diagnosis_list,
            correct_diagnosis=correct_diagnosis,
            model_name=args.judge_correctness_llm,
            LLM_caller=LLM_caller,
            other_info=scenario.full_record(),
        )

        dialog_list = [
            f"\n### {_scenario_id}\n",
            f"{record_summary}\n",
            f"Correct Diagnosis: {correct_diagnosis}\n",
        ]
        dialog_list.extend(
            f"Diagnosis {rank}: {item['diagnosis']}\n"
            + ("**CORRECT**" if item["correctness"] else "**INCORRECT**")
            + f" {item['level']}\n"
            for rank, item in enumerate(correctness_list, start=1)
        )

        info_dict = {
            "scenario_id": _scenario_id,
            "correct_diagnosis": correct_diagnosis,
            "difficulty": scenario.difficulty_level(),
            "print_dialog": dialog_list,
            "top3_diagnosis": correctness_list,
            "detailed_log": LLM_caller.LLM_log_list,
        }

        return info_dict, _scenario_id, _repeat
    except Exception as e:
        # Worker boundary: report and let main() mark the scenario as skipped.
        print("Error:", _scenario_id, _repeat, e, flush=True)
        traceback.print_exc()
        return None, _scenario_id, _repeat


def main():
    """
    Run every (scenario, repeat) pair in a process pool and save the results.

    Metrics are not computed here (see ``eval``). Results are written to
    ``detailed_log_static/{args.eval_id}-full-log.json``, ordered by
    (scenario_id, repeat). Scenarios that fail are reported as skipped and
    left out of the results.
    """
    scenario_loader = Scenario_OSCE_Loader(args.dataset_path)

    # Clamp to the dataset size; eval() later reads this value from settings.
    args.num_scenarios = min(args.num_scenarios, scenario_loader.num_scenarios)
    total_num_scenarios = args.num_scenarios

    collected = []  # (scenario_id, repeat, info_dict)
    with ProcessPoolExecutor(max_workers=args.parallel_thread_num) as executor:
        future_to_key = {}
        for _repeat in range(args.repeat_cnt):
            for _scenario_id in range(total_num_scenarios):
                future = executor.submit(
                    run_one_scenario,
                    _scenario_id,
                    scenario_loader.get_scenario(_scenario_id),
                    _repeat,
                )
                future_to_key[future] = (_scenario_id, _repeat)

        for future in as_completed(future_to_key):
            _scenario_id, _repeat = future_to_key[future]
            try:
                info_dict, _, _ = future.result()
            except Exception as e:
                # run_one_scenario catches its own errors, so this is a failure
                # outside it (e.g. a crashed worker or an unpicklable result).
                # Don't let one scenario discard every finished result.
                print(f"Scene {_scenario_id}-{_repeat}: worker error {e!r}", flush=True)
                info_dict = None
            if info_dict:  # not None
                print(f"Scene {_scenario_id}-{_repeat}: done", flush=True)
                collected.append((_scenario_id, _repeat, info_dict))
            else:
                print(f"Scene {_scenario_id}-{_repeat}: skipped", flush=True)

    # Completion order is nondeterministic; sort repeats too for stable logs.
    collected.sort(key=lambda entry: (entry[0], entry[1]))
    results_list = [info for _, _, info in collected]

    os.makedirs("detailed_log_static", exist_ok=True)
    write_json(
        {"settings": args.to_dict(), "results_list": results_list},
        f"detailed_log_static/{args.eval_id}-full-log.json",
    )


def eval():
    """
    Evaluate a saved run and print metrics.

    Reads ``detailed_log_static/{args.eval_id}-full-log.json``, prints every
    scenario's dialog, the 4-level distribution of the first-ranked diagnosis,
    and top-1/2/3 accuracy, then re-saves the log with its ORIGINAL settings.

    The denominator is ``num_scenarios * repeat_cnt`` from the saved settings,
    so scenarios skipped during the run count as incorrect.

    Note: this name shadows the builtin ``eval`` inside this module; it is kept
    for compatibility with existing callers.
    """
    log_path = f"detailed_log_static/{args.eval_id}-full-log.json"
    saved_info = read_json(log_path)
    settings = saved_info["settings"]
    results_list = saved_info["results_list"]
    total_num_scenarios = settings["num_scenarios"] * settings["repeat_cnt"]
    # Guard against an empty run (num_scenarios or repeat_cnt of 0).
    denominator = total_num_scenarios or 1

    # Print the dialogs (no chain-of-thought).
    print("\n".join(_log for res in results_list for _log in res["print_dialog"]))

    # 4-level correctness of the first-ranked diagnosis.
    levels = ["完全正确", "临床可接受", "部分正确", "不正确"]
    level_cnt = dict.fromkeys(levels, 0)
    for res in results_list:
        ranked = res["top3_diagnosis"]
        # An empty judged list has no correct diagnosis; count it as incorrect.
        level = ranked[0]["level"] if ranked else "不正确"
        level_cnt[level] = level_cnt.get(level, 0) + 1
    for level in levels:
        print(
            f"{level}: {level_cnt.get(level,0)} cases, "
            f"{level_cnt.get(level,0)/denominator*100:.2f}%"
        )

    # Top-k: a hit if any of the first k judged diagnoses is correct. Slicing
    # tolerates lists shorter than k.
    accuracy_lines = []
    for k in (1, 2, 3):
        hits = sum(
            any(item["correctness"] for item in res["top3_diagnosis"][:k])
            for res in results_list
        )
        accuracy_lines.append(f"Top-{k} Accuracy: {hits / denominator * 100:.2f}%")
    print("\n".join(accuracy_lines))

    # Write back the run's own settings, not the current CLI args. In eval-only
    # mode, args may differ from (e.g. an unclamped num_scenarios) the run that
    # produced this log.
    write_json(
        {
            "settings": settings,
            "results_list": results_list,
        },
        log_path,
    )


if __name__ == "__main__":
    # Without an eval_id: run a fresh experiment (and optionally evaluate it).
    # With an eval_id: only evaluate an existing log.
    run_failed = False
    if not args.eval_id:
        current_time = datetime.now().strftime("%Y%m%d-%H%M")
        args.eval_id = f"static-{current_time}-{Path(args.dataset_path).stem}--doctor={args.doctor_llm.split('/')[-1]}"
        print(args.eval_id, flush=True)

        if args.doctor_llm_path:
            launch_vllm_server(
                model_path=args.doctor_llm_path, model_name=args.doctor_llm
            )
            # NOTE(review): kill_vllm_server is imported but never called, so the
            # launched server may outlive this script. Confirm its signature and
            # call it in a `finally` around main().
        try:
            main()
        except Exception as e:
            print(e)
            traceback.print_exc()
            run_failed = True
    else:
        print("Only eval() from", args.eval_id, flush=True)

    # Evaluation metrics. main() writes its log only as its last step, so there
    # is nothing to evaluate when it raised.
    if args.evaluate and not run_failed:
        eval()
