from environment_real_only_auxiliary import (
    Warning_from_too_many_auxiliary_exam,
    LLM_Caller_for_One_Thread,
)
from .interface import BaseDoctorAgent

long_diagnosis_system_prompt = (
    "你是一名专业医生，负责根据提供的病人信息，进行鉴别诊断。"
)
diagnosis_user_prompt = "<病历>\n{}\n</病历>"

action_inquiry_physical_system_prompt = (
    "你负责根据提供的病人信息，决定医生下一步要做的问诊或体格检查。"
)
action_inquiry_physical_user_prompt = (
    "<病历>\n{}\n</病历>\n\n<上一步动作>\n{}\n</上一步动作>"
)

action_auxiliary_item_system_prompt = "你是一名专业医生，负责根据提供的病人信息和鉴别诊断，决定医生下一步要进行哪些辅助检查"
action_exam_or_diagnosis = "你是一名专业医生，负责根据提供的病人信息和鉴别诊断，决定医生下一步是继续进行辅助检查，还是直接给出最终诊断。"
action_user_prompt = "<病历>\n{}\n</病历>\n\n<鉴别诊断>\n{}\n</鉴别诊断>"

start_auxiliary_key_words = "终止体格检查，开始进入辅助检查阶段"


class DoctorAgent_SFT_1221(BaseDoctorAgent):
    """
    只适用于本课题的微调LLM；不可用于其他基于prompt的LLM
    """

    def __init__(
        self,
        LLM_caller: LLM_Caller_for_One_Thread,
        model_name="gpt4",
        current_summary="",
    ) -> None:
        super().__init__(LLM_caller, model_name, current_summary)
        self.generation_kwargs = {"max_tokens": 12288, "temperature": 0.7}
        self.top3_diagnosis = [None] * 3

    def long_diagnosis(self, record_summary: str) -> list:
        """
        只诊断，返回top-k Diagnosis
        """
        # long 鉴别诊断
        messages = [
            {"role": "system", "content": long_diagnosis_system_prompt},
            {"role": "user", "content": diagnosis_user_prompt.format(record_summary)},
        ]
        for _ in range(5):
            raw_answer = self.LLM_caller.query_model(
                model_str=self.model_name,
                messages=messages,
                role="long diagnosis",
                ensure_label=None,
                generation_kwargs=self.generation_kwargs,
            )
            diagnosis_str = raw_answer.split("</reason>")[-1].strip()
            if diagnosis_str and len(diagnosis_str) <= 1000:
                break
            else:
                print(diagnosis_str, "不符合格式：鉴别")
                diagnosis_str = None
        if not diagnosis_str:
            raise ValueError("辅检时鉴别诊断失败")

        # 地毯式验证
        diagnosis_list = [d.strip() for d in diagnosis_str.split("\n") if d.strip()]

        # 记录最新的top-3
        self.top3_diagnosis = (diagnosis_list + [None] * 3)[:3]
        return diagnosis_list

    def decide_exam_or_diagnosis(self, summary: str, diagnosis_str: str) -> str:
        messages = [
            {"role": "system", "content": action_exam_or_diagnosis},
            {
                "role": "user",
                "content": action_user_prompt.format(summary, diagnosis_str),
            },
        ]
        valid_answer = False
        for _ in range(5):
            raw_answer = self.LLM_caller.query_model(
                model_str=self.model_name,
                messages=messages,
                role="action auxiliary",
                ensure_label=None,
                generation_kwargs=self.generation_kwargs,
            )
            decision_answer = raw_answer.split("</reason>")[-1].strip()
            if decision_answer.startswith("您的诊断结果为"):  # 选择诊断
                valid_answer = True
                break
            if decision_answer == "继续辅助检查":
                valid_answer = True
                break
            print(decision_answer, "不符合格式：决定")
        if not valid_answer:
            raise ValueError("决定是否继续检查 失败")
        return decision_answer

    def recommend_auxiliary(self, summary: str, diagnosis_str: str) -> str:
        messages = [
            {"role": "system", "content": action_auxiliary_item_system_prompt},
            {
                "role": "user",
                "content": action_user_prompt.format(summary, diagnosis_str),
            },
        ]
        for _ in range(5):
            raw_answer = self.LLM_caller.query_model(
                model_str=self.model_name,
                messages=messages,
                role="action auxiliary",
                ensure_label=None,
                generation_kwargs=self.generation_kwargs,
            )
            action_str = raw_answer.split("</reason>")[-1].strip()
            if action_str and len(action_str) <= 800:
                break
            else:
                print(action_str, "不符合格式：辅检")
                action_str = None
        if not action_str:
            raise ValueError("辅检动作失败")
        return action_str

    def _get_next_action(self, patient_answer: str, summary: str) -> str:
        # long 鉴别诊断
        diagnosis_list = self.long_diagnosis(record_summary=summary)
        diagnosis_str = "\n".join(diagnosis_list)

        decision = self.decide_exam_or_diagnosis(
            summary=summary, diagnosis_str=diagnosis_str
        )
        if decision.startswith("您的诊断结果为"):
            return decision
        if patient_answer == Warning_from_too_many_auxiliary_exam:
            # 强制停止辅助检查，给出一个诊断
            return f"您的诊断结果为：{diagnosis_list[0]}"
        # 下一步动作
        action_str = self.recommend_auxiliary(summary, diagnosis_str)
        return action_str

    def get_top3_diagnosis(self) -> list:
        return self.top3_diagnosis
