from typing import Optional

from utils.call_LLM import LLM_Caller_for_One_Thread
from utils.data_loader import Scenario_OSCE
from utils.tools import batch_process_parallel, extract_label
from prompt_Chinese import (
    Warning_from_too_many_auxiliary_exam,
    SYSTEM_PROMPT_judge_diagnosis_correct_level,
    USER_PROMPT_judge_diagnosis_correct_level,
    prompt_makeup_one_exam_item,
)
from config_test import args


def judge_correct_diagnosis_level(
    correct_diagnosis: str,
    diagnosis_to_be_judged: str,
    LLM_caller: LLM_Caller_for_One_Thread,
    other_info=None,
    model_name: str = "GLZ-Z1-Flash",
    level_list: list = None,
) -> tuple[bool, str]:
    """Grade a doctor's diagnosis against the reference diagnosis with an LLM judge.

    The judge is queried up to three times. The first reply that is one of
    ``level_list`` is returned. A failed attempt (exception or unrecognised
    label) does not cancel the remaining attempts.

    Args:
        correct_diagnosis: Reference diagnosis, shown to the judge as ground truth.
        diagnosis_to_be_judged: The diagnosis to grade.
        LLM_caller: Object whose ``query_model_and_extract_label`` is called to
            obtain the judge's label.
        other_info: Accepted for interface compatibility; it is currently not
            included in the prompt.
        model_name: Model identifier forwarded to ``LLM_caller``.
            NOTE(review): the default "GLZ-Z1-Flash" may be a typo for
            "GLM-Z1-Flash" - confirm against the configured model names.
        level_list: Labels accepted from the judge. The last entry is the
            fallback when no valid label is obtained. Defaults to
            ["完全正确", "临床可接受", "部分正确", "不正确"].

    Returns:
        ``(correctness, level)``. ``correctness`` is True only when ``level`` is
        "完全正确" or "临床可接受". When every attempt fails the result is
        ``(False, level_list[-1])``.
    """
    # Build the default per call instead of sharing one mutable list object.
    if level_list is None:
        level_list = ["完全正确", "临床可接受", "部分正确", "不正确"]

    try:
        prefix = (
            f"最终诊断 (绝对正确)：\n{correct_diagnosis}\n\n"
            + f"医生诊断 (有待评估):\n{diagnosis_to_be_judged}\n\n"
        )
        prompt = USER_PROMPT_judge_diagnosis_correct_level.format(prefix)
    except Exception as e:
        # A prompt that cannot be built will not improve on retry.
        print("Error in judge_correct_diagnosis_level:", e, flush=True)
        return False, level_list[-1]

    for _ in range(3):
        try:
            ans = LLM_caller.query_model_and_extract_label(
                model_str=model_name,
                prompt=prompt,
                system_prompt=SYSTEM_PROMPT_judge_diagnosis_correct_level,
                role="Judge Diagnosis Level",
                ensure_label="answer",
            )
        except Exception as e:
            # Handle the failure per attempt so one error does not silently
            # grade the diagnosis as incorrect while retries remain.
            print("Error in judge_correct_diagnosis_level:", e, flush=True)
            continue
        if ans in level_list:
            correctness = ans in ["完全正确", "临床可接受"]
            return correctness, ans
    return False, level_list[-1]


def makeup_one_exam_result(
    item: str,
    scenario: Scenario_OSCE,
    model_name: str,
    LLM_caller: LLM_Caller_for_One_Thread = None,
):
    """Return the result text for one requested auxiliary-exam item.

    The item is first looked up in ``scenario.tests["extra"]``. If there is no
    usable recorded string, the LLM is asked (at most twice) to produce a
    result from the patient record.

    Args:
        item: Requested exam item; a leading "-" bullet marker is stripped.
        scenario: Scenario providing ``tests`` and
            ``record_for_auxiliary_exam_agent()``.
        model_name: Model identifier forwarded to ``LLM_caller.query_model``.
        LLM_caller: Caller used for the LLM fallback. A new
            ``LLM_Caller_for_One_Thread`` is created when it is falsy.

    Returns:
        The recorded result as stored, or "- <answer>" from the LLM, or a
        "result lost" placeholder line when both LLM attempts fail.
    """
    # Drop the bullet marker so the bare item name can be used as a key.
    if item.startswith("-"):
        item = item[1:].strip()

    # Try a direct lookup in the recorded results first.
    extra_results = scenario.tests.get("extra", {})
    if item in extra_results:
        ans = extra_results[item]
        if ans and isinstance(ans, str):
            return ans

    # No usable recorded result: let the LLM answer from the patient record.
    patient_record = scenario.record_for_auxiliary_exam_agent()
    if not LLM_caller:
        LLM_caller = LLM_Caller_for_One_Thread()
    for _ in range(2):
        try:
            response = LLM_caller.query_model(
                model_str=model_name,
                prompt=prompt_makeup_one_exam_item.format(patient_record, item),
                system_prompt=None,
                role="MedTestAgent",
                ensure_label="answer",
            )
            # Only the "answer" label is used; any "hit" label in the response
            # is ignored because nothing consumes it.
            ans = extract_label(response, "answer")
        except Exception as e:
            # Still best-effort, but the failure is now visible and
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            print("Error in makeup_one_exam_result:", e, flush=True)
            continue
        if ans:
            return f"- {ans}"
    return f"- {item}：结果丢失，请重新申请此检查。"


class MedTestAgent:
    """
    Answers a doctor's auxiliary-exam request with one result per requested item.
    """

    def __init__(
        self,
        LLM_caller: LLM_Caller_for_One_Thread,
        scenario: Scenario_OSCE,
        model_name="gpt",
    ) -> None:
        # Caller handed to every per-item lookup.
        self.LLM_caller = LLM_caller
        # Model name handed to every per-item lookup.
        self.model_name = model_name
        self.scenario = scenario

    def inference(self, question: str) -> str:
        """Produce results for every "-"-prefixed line of ``question``.

        Each line that starts with "-" is treated as one requested item and is
        passed to ``makeup_one_exam_result`` through ``batch_process_parallel``.
        The per-item answers are joined with newlines. When no such line
        exists, a system warning asking for the correct format is returned.
        """
        requested_items = []
        for line in question.split("\n"):
            if line.startswith("-"):
                requested_items.append(line)

        if not requested_items:
            return "来自系统的警告：请你严格按照格式重新输出。"

        # One argument list per requested item, one worker per item.
        jobs = [
            [entry, self.scenario, self.model_name, self.LLM_caller]
            for entry in requested_items
        ]
        per_item_answers = batch_process_parallel(
            func=makeup_one_exam_result,
            args_list=jobs,
            num_processes=len(requested_items),
        )
        return "\n".join(per_item_answers)


class Environment:
    """
    Runs one scenario: serves auxiliary-exam results to the doctor until a
    final diagnosis is given, then grades that diagnosis.
    """

    def __init__(self, scenario_id: int, scenario: Scenario_OSCE) -> None:
        self.scenario_id = scenario_id
        self.scenario = scenario
        self.LLM_caller = LLM_Caller_for_One_Thread(
            introduction_log=f"\n<hr>\n\n## {scenario_id} Environment\n"
        )
        self.current_stage = "Auxiliary Exam"
        self.med_test_agent = MedTestAgent(
            LLM_caller=self.LLM_caller,
            scenario=scenario,
            model_name=args.measurement_llm,
        )
        # for log
        self.dialog_list = []
        # for interaction tracking
        self.interaction_count = 0
        # Last reply sent to the doctor; None until the first exam request.
        self.answer_to_doctor = None
        self.dialog_of_auxiliary_exam = ""
        self.final_diagnosis = ""
        self.correct_diagnosis = scenario.diagnosis_information()
        self.difficulty_level = scenario.difficulty_level()

    def get_initial_summary(self) -> str:
        """Build the opening summary (history + physical exam) and log it.

        The summary is stored as ``dialog_of_auxiliary_exam``. A scenario
        header and the summary are appended to ``dialog_list``.
        """
        record = "病史信息：\n" + self.scenario.patient_information() + "\n\n"
        record += "体格检查结果：\n" + self.scenario.short_physical_exam_information()
        summary = record
        self.dialog_of_auxiliary_exam = summary
        self.dialog_list.append(f"### {self.scenario_id}")
        self.dialog_list.append(summary + "\n")
        return summary

    def judge_diagnosis_correctness_and_level(
        self, diagnosis_to_be_judged: str
    ) -> tuple[bool, str]:
        """Grade ``diagnosis_to_be_judged`` against the scenario's diagnosis.

        Delegates to ``judge_correct_diagnosis_level`` using this
        environment's LLM caller and ``args.judge_correctness_llm``.
        """
        return judge_correct_diagnosis_level(
            correct_diagnosis=self.scenario.diagnosis_information(),
            diagnosis_to_be_judged=diagnosis_to_be_judged,
            LLM_caller=self.LLM_caller,
            other_info=self.scenario.full_record(),
            model_name=args.judge_correctness_llm,
        )

    def judge_stage(self, question) -> bool:
        """Update ``current_stage`` from the doctor's message.

        The stage becomes "Final Diagnosis" when the lower-cased message
        contains any of the final-diagnosis keywords, otherwise
        "Auxiliary Exam".

        Returns:
            True when the message was classified as a final diagnosis.
        """
        lowered = question.lower()
        is_final = any(
            word in lowered
            for word in ("final diagnosis", "诊断结果", "核心诊断", "完整诊断")
        )
        self.current_stage = "Final Diagnosis" if is_final else "Auxiliary Exam"
        return is_final

    def response_to_doctor(self, question_from_doctor: str) -> Optional[str]:
        """Handle one doctor message and return the environment's reply.

        Returns:
            The exam results or the too-many-exams warning, or ``None`` once
            the doctor has given a final diagnosis (end of interaction).

        Raises:
            ValueError: when the number of exam requests reaches
                ``args.auxiliary_exam_num + 2``.
        """
        # The interaction is already over; ignore further messages.
        if self.current_stage == "Final Diagnosis":
            return None

        self.dialog_list.append(f"**Doctor:** {question_from_doctor}")

        # judge the stage of doctor
        self.judge_stage(question_from_doctor)

        if self.current_stage == "Final Diagnosis":
            # Record the exam dialog without the first entry and without the
            # final-diagnosis message that was just appended.
            # NOTE(review): assumes get_initial_summary() ran first so that
            # dialog_list[0] is the scenario header - confirm with callers.
            self.dialog_of_auxiliary_exam = "\n".join(self.dialog_list[1:-1])
            self.final_diagnosis = question_from_doctor
            return None  # indicates the end of interaction

        # stage Auxiliary Exam
        self.interaction_count += 1
        if self.interaction_count >= args.auxiliary_exam_num:
            # Warn first; abort once the doctor keeps requesting exams.
            if self.interaction_count >= args.auxiliary_exam_num + 2:
                raise ValueError("太多次辅检")
            self.answer_to_doctor = Warning_from_too_many_auxiliary_exam
            self.dialog_list.append(f"{self.answer_to_doctor}\n\n")
            return self.answer_to_doctor

        self.answer_to_doctor = self.med_test_agent.inference(question_from_doctor)
        self.dialog_list.append(f"**Results:** {self.answer_to_doctor}")
        return self.answer_to_doctor

    def info_dict(self) -> dict:
        """Grade the final diagnosis and return the result record.

        Each call runs the LLM judge on ``final_diagnosis`` and appends two
        summary entries to ``dialog_list`` before building the dictionary.
        """
        correctness, correctness_level = self.judge_diagnosis_correctness_and_level(
            self.final_diagnosis
        )
        # Update log list
        self.dialog_list.append(
            f"\nScene {self.scenario_id}, Correct answer: **{self.correct_diagnosis}**"
            # Separator added: the two fields were previously fused together.
            + f", difficulty: {self.difficulty_level}"
        )
        self.dialog_list.append(
            f"0. Final Diagnosis was \n```\n{self.final_diagnosis}\n```\n"
            + ("**CORRECT**" if correctness else "**INCORRECT**")
            + f"  {correctness_level}\n\n"
        )
        return {
            "scenario_id": self.scenario_id,
            "correct_diagnosis": self.correct_diagnosis,
            "diagnosis": [
                self.final_diagnosis,
            ],
            "correctness": correctness,
            "correctness_level": correctness_level,
            "difficulty": self.difficulty_level,
            "interaction_count": [
                self.interaction_count,
            ],
            "print_dialog": self.dialog_list,
            "env_detailed_log": self.LLM_caller.LLM_log_list,
            "dialog_no_diagnosis": self.dialog_of_auxiliary_exam,
        }
