import logging
from typing import Optional

try:
    from .environment_real_only_auxiliary import (
        judge_correct_diagnosis_level,
        LLM_Caller_for_One_Thread,
    )
except:
    from auxiliary_benchmark.environment_real_only_auxiliary import (
        judge_correct_diagnosis_level,
        LLM_Caller_for_One_Thread,
    )


# Prompt template (written in Chinese) that asks the model to act as a medical
# expert and derive a ranked top-5 differential diagnosis list from a patient
# record. The model is told to write a step-by-step analysis first and then put
# one diagnosis per line inside <answer>...</answer>; fewer than five entries
# are explicitly allowed.
#
# The single `{}` placeholder is filled with the record text via str.format()
# in diagnose_from_record, so any literal braces added to this template must be
# doubled (`{{` / `}}`).
#
# NOTE(review): the closing sentence "不要给出其他无内容" looks like a typo for
# "不要给出其他无关内容" ("do not output other irrelevant content"). Confirm
# before changing it, since editing the prompt alters what the model receives.
prompt_diagnosis_from_record = """你是一名医学专家，你的任务是根据提供的患者病历，推理出top-5鉴别诊断列表。

**核心要求：**
- **诊断排序**：诊断列表按可能性由高到低排列。
- **诊断完整性**：每个诊断都应是完整的，应具体明确（例如：使用“右下叶肺炎”而非“肺炎”；“冠心病不稳定型心绞痛”而非“心脏病”）；可包含主要疾病和相关的并发症/合并症（例如：2型糖尿病 合并 社区获得性肺炎）。
- **诊断竞争性**：列表中的各项诊断应该是相互竞争的备选方案（即鉴别诊断）。**不要将一个统一病理过程的不同方面拆分成独立的条目**（如将“社区获得性肺炎”和“发热”分别列为两个诊断）。
- **聚焦诊断**：你的回答应专注于诊断推理过程和最终的诊断列表。**严禁**提供任何治疗方案、用药建议或健康指导，也不要包含病人的检查结果等信息。
- **诊断个数**：允许鉴别诊断个数不足5个。


以下是病人的信息：
<病历>
{}
</病历>


输出格式：
逐步的分析...
<answer>
诊断1
诊断2
...
</answer>

输出示例：
...（逐步的分析）
<answer>
结核性脑膜炎/脑膜脑炎，伴有社区获得性肺炎
鼻窦旁脓肿，并发结核性全身感染
</answer>


现在请先给出逐步的分析，然后输出若干个相互竞争、完整的诊断方案，不要给出其他无内容。
"""


def diagnose_from_record(
    record: str,
    model_name: str,
    LLM_caller: Optional[LLM_Caller_for_One_Thread] = None,
) -> Optional[str]:
    """Query the model for a differential diagnosis and return the top entry.

    The patient record is inserted into ``prompt_diagnosis_from_record`` and
    sent through ``LLM_caller.query_model_and_extract_label`` with
    ``ensure_label="answer"``. The returned text is treated as a
    newline-separated list (presumably the content of the ``<answer>`` block —
    confirm against the caller implementation) and only its first line is kept.

    Args:
        record: Patient record text substituted into the prompt template.
        model_name: Model identifier forwarded as ``model_str``.
        LLM_caller: Caller used for the query; a fresh
            ``LLM_Caller_for_One_Thread`` is created when omitted.

    Returns:
        The first line of the extracted answer with surrounding whitespace
        removed (an empty string if the answer is blank), or ``None`` when the
        query or the extraction fails.
    """
    try:
        if LLM_caller is None:
            LLM_caller = LLM_Caller_for_One_Thread()
        diagnosis_top5 = LLM_caller.query_model_and_extract_label(
            model_str=model_name,
            prompt=prompt_diagnosis_from_record.format(record),
            system_prompt=None,
            role="Doctor",
            ensure_label="answer",
        )
        # Strip first so a leading newline inside the answer block does not
        # turn the "top" diagnosis into an empty string.
        return diagnosis_top5.strip().split("\n")[0].strip()
    except Exception:
        # Best-effort by design: callers treat None as "no diagnosis". Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate, and the log keeps the failure visible.
        logging.getLogger(__name__).warning(
            "diagnose_from_record failed for model %r", model_name, exc_info=True
        )
        return None


def diagnose_from_record_and_judge_correctness(
    record: str,
    correct_diagnosis: str,
    model_name: str,
    LLM_caller: Optional[LLM_Caller_for_One_Thread] = None,
    other_info: str = "",
) -> dict:
    """Diagnose from a patient record and grade the result against the truth.

    Args:
        record: Patient record text passed to ``diagnose_from_record``.
        correct_diagnosis: Reference diagnosis the prediction is judged against.
        model_name: Model identifier used for both the diagnosis and the judge.
        LLM_caller: Caller shared by both steps; a fresh
            ``LLM_Caller_for_One_Thread`` is created when omitted.
        other_info: Extra context forwarded to the judge.

    Returns:
        A dict with keys ``"diagnosis"``, ``"correctness"`` and ``"level"``.
        When no diagnosis could be produced the result is
        ``{"diagnosis": None, "correctness": False, "level": "不正确"}``.
    """
    if LLM_caller is None:
        LLM_caller = LLM_Caller_for_One_Thread()
    diagnosis = diagnose_from_record(
        record=record, model_name=model_name, LLM_caller=LLM_caller
    )
    if diagnosis is None:
        # Same handling as judge_topk: a missing diagnosis is recorded as
        # incorrect without asking the judge to grade None.
        return {"diagnosis": None, "correctness": False, "level": "不正确"}
    correctness, level = judge_correct_diagnosis_level(
        correct_diagnosis=correct_diagnosis,
        diagnosis_to_be_judged=diagnosis,
        LLM_caller=LLM_caller,
        other_info=other_info,
        model_name=model_name,
    )
    return {"diagnosis": diagnosis, "correctness": correctness, "level": level}


def judge_topk(
    topk_diagnosis: list,
    correct_diagnosis: str,
    model_name: str,
    LLM_caller: LLM_Caller_for_One_Thread = None,
    other_info: str = "",
):
    """Grade every candidate diagnosis against the reference diagnosis.

    Args:
        topk_diagnosis: Candidate diagnoses; entries may be ``None``.
        correct_diagnosis: Reference diagnosis each candidate is judged against.
        model_name: Model identifier forwarded to the judge.
        LLM_caller: Caller used by the judge; a fresh
            ``LLM_Caller_for_One_Thread`` is created when omitted.
        other_info: Extra context forwarded to the judge.

    Returns:
        A list, in input order, of dicts with keys ``"diagnosis"``,
        ``"correctness"`` and ``"level"``. ``None`` candidates are recorded as
        incorrect without calling the judge.
    """
    judge_caller = LLM_Caller_for_One_Thread() if LLM_caller is None else LLM_caller

    def _grade(candidate):
        # A missing candidate cannot be judged; record it as incorrect.
        if candidate is None:
            return {"diagnosis": None, "correctness": False, "level": "不正确"}
        is_correct, grade = judge_correct_diagnosis_level(
            correct_diagnosis=correct_diagnosis,
            diagnosis_to_be_judged=candidate,
            LLM_caller=judge_caller,
            model_name=model_name,
            other_info=other_info,
        )
        return {"diagnosis": candidate, "correctness": is_correct, "level": grade}

    return [_grade(candidate) for candidate in topk_diagnosis]
