import sys

sys.path.append("../")
from utils.call_LLM import (
    LLM_Caller_for_One_Thread,
    get_context_id,
    cached_LLM_id_list,
)
from utils.tools import extract_label
from utils.io_func import write_json, read_json
import re, time
from datetime import datetime
from utils.data_loader import Scenario_OSCE_Loader, Scenario_OSCE
from pathlib import Path
import random
import json
from concurrent.futures import as_completed, ProcessPoolExecutor
from config_generator import args
from prompt_English import *

# The English prompts are star-imported above; when the Chinese flag is set,
# the Chinese prompt module is star-imported afterwards, so any name it
# defines rebinds the English one.
if args.Chinese:
    from prompt_Chinese import *

# Maps a system prompt to its context id. Filled by prepare_cache_context_id()
# and passed to every LLM_Caller_for_One_Thread created in run_one_scenario().
global_dict_from_sys_prompt_to_context_id = {}


def transform_messages_to_conversation_str(messages: list) -> str:
    """
    Render a list of chat messages as a plain-text conversation.

    Each message is a dict such as::

        {"role": "user", "content": "...", "stage": "Inquiry"}
        {"role": "assistant", "content": "...", "stage": "Inquiry"}

    Assistant messages become ``"Doctor: <content>\\n"``. Every other message
    is labelled according to its ``stage``: ``Patient`` for "Inquiry" and
    ``Results`` for "Physical Exam" / "Auxiliary Exam", each followed by a
    blank line. Non-assistant messages in any other stage are left out.
    """
    # Label used for the non-assistant side of the dialogue, per stage.
    reply_labels = {
        "Inquiry": "Patient",
        "Physical Exam": "Results",
        "Auxiliary Exam": "Results",
    }
    fragments = []
    for turn in messages:
        if turn["role"] == "assistant":
            fragments.append(f"Doctor: {turn['content']}\n")
        else:
            label = reply_labels.get(turn["stage"])
            if label is not None:
                fragments.append(f"{label}: {turn['content']}\n\n")
    return "".join(fragments)


def is_valid_messages(messages):
    """Check that roles strictly alternate and every message has content.

    The sequence is treated as if it were preceded by an assistant turn, so
    a list whose first message has role "assistant" is rejected. Returns
    False as soon as two consecutive messages share a role or a message
    lacks the "content" key; returns True otherwise (including for an
    empty list).
    """
    previous_role = "assistant"
    for entry in messages:
        current_role = entry["role"]
        if current_role == previous_role or "content" not in entry:
            return False
        previous_role = current_role
    return True


def judge_correct_diagnosis_level(
    correct_diagnosis: str,
    diagnosis_to_be_judged: str,
    LLM_caller: LLM_Caller_for_One_Thread,
    other_info=None,
    model_name: str = "GLZ-Z1-Flash",
    level_list: list = None,
) -> tuple[bool, str]:
    """Ask a judge model how well a diagnosis matches the reference answer.

    Args:
        correct_diagnosis: The reference (ground-truth) diagnosis.
        diagnosis_to_be_judged: The diagnosis to evaluate.
        LLM_caller: Caller used to query the judge model.
        other_info: Accepted for backward compatibility; currently not
            included in the prompt.
        model_name: Model identifier passed to the caller.
        level_list: Accepted grade labels, last entry being the fallback
            grade. Defaults to
            ["完全正确", "临床可接受", "部分正确", "不正确"].

    Returns:
        ``(correctness, level)``. ``correctness`` is True only when the
        judge answers "完全正确" or "临床可接受". If either diagnosis is
        empty, the judge never returns a label from ``level_list`` within
        three attempts, or the query raises, the result is
        ``(False, level_list[-1])``.
    """
    # Build the default per call instead of sharing one mutable list object
    # across all invocations.
    if level_list is None:
        level_list = ["完全正确", "临床可接受", "部分正确", "不正确"]
    fallback_level = level_list[-1]

    if not diagnosis_to_be_judged or not correct_diagnosis:
        return False, fallback_level
    try:
        prefix = (
            f"最终诊断 (绝对正确的标准答案)：\n{correct_diagnosis}\n\n"
            + f"医生诊断 (有待评估):\n{diagnosis_to_be_judged}\n\n"
        )
        # Retry while the judge replies with something outside level_list.
        for _ in range(3):
            ans = LLM_caller.query_model_and_extract_label(
                model_str=model_name,
                prompt=USER_PROMPT_judge_diagnosis_correct_level.format(prefix),
                system_prompt=SYSTEM_PROMPT_judge_diagnosis_correct_level,
                role="Judge Diagnosis Level",
                ensure_label="answer",
            )
            if ans in level_list:
                return ans in ("完全正确", "临床可接受"), ans
    except Exception as e:
        # Best effort: a failing judge call counts as the fallback grade.
        print("Error in judge_correct_diagnosis_level:", e, flush=True)
    return False, fallback_level


def compare_results(
    diagnosis_to_be_verified, correct_diagnosis, LLM_caller: LLM_Caller_for_One_Thread
) -> bool:
    """Return whether the diagnosis is judged correct against the reference.

    Delegates to judge_correct_diagnosis_level using ``args.cheap_llm`` as
    the judge model and keeps only the boolean verdict.
    """
    is_correct, _level = judge_correct_diagnosis_level(
        correct_diagnosis=correct_diagnosis,
        diagnosis_to_be_judged=diagnosis_to_be_verified,
        LLM_caller=LLM_caller,
        model_name=args.cheap_llm,
    )
    return is_correct



def generate_auxiliary_exam(
    scenario: Scenario_OSCE,
    generated_messages: list,
    summary_inquiry_physical: str,
    LLM_caller: LLM_Caller_for_One_Thread,
    last_physical_action: str = "一般状况与生命体征（意识状态、生命体征、整体外观、身高体重）",
    step_by_step_ratio=0.0,
):
    """Generate the auxiliary-exam stage of a dialogue, ending in a diagnosis.

    Runs at most ``args.auxiliary_exam_num`` rounds. Each round:

    1. builds a case summary (the first round uses
       ``summary_inquiry_physical`` as is; later rounds append a summary of
       the latest exam request/result pair, or with probability 0.5 use the
       raw exam dialogue accumulated so far instead);
    2. asks ``args.expensive_llm`` for a differential diagnosis;
    3. asks whether to continue with auxiliary exams or to diagnose. A
       diagnosis is only accepted when ``compare_results`` judges it correct
       against ``scenario.diagnosis_information()``;
    4. either appends the final "Diagnosis" message and returns, or appends
       an exam request plus exam results produced by ``args.cheap_llm``.

    Args:
        scenario: Case providing the record and the reference diagnosis.
        generated_messages: Messages generated so far; extended in place.
        summary_inquiry_physical: Summary of history and physical exam.
        LLM_caller: Caller used for all model queries.
        last_physical_action: Stored as ``last_action`` on the first message.
        step_by_step_ratio: Currently unused.

    Returns:
        ``generated_messages`` (the same list object), which may end without
        a "Diagnosis" message if the round limit is reached.

    Raises:
        ValueError: If the differential diagnosis or the exam recommendation
            does not have the expected format after three attempts.
    """
    # Physical exam -> auxiliary exam hand-over: the first doctor message of
    # this stage carries the inquiry + physical summary unchanged.
    first_auxiliary_message = {
        "role": "assistant",
        "content": "",
        "stage": "Auxiliary Exam",
        "summary": summary_inquiry_physical,
        "action_reasoning_inquiry+physical": "",
        "last_action": last_physical_action,
    }

    conversation_summary = summary_inquiry_physical + "\n辅助检查结果：\n"
    # Verbatim exam dialogue, used as an alternative to the LLM summaries.
    auxiliary_dialog_str = conversation_summary

    for auxiliary_cnt in range(args.auxiliary_exam_num):
        if auxiliary_cnt == 0:
            message = first_auxiliary_message
        else:
            message = {
                "role": "assistant",
                "content": "",
                "stage": "Auxiliary Exam",
            }
            # Summarize the newest exam request/result pair.
            conversation_history = transform_messages_to_conversation_str(
                generated_messages[-2:]
            )
            try:
                temp_summary = LLM_caller.query_model_and_extract_label(
                    model_str=args.cheap_llm,
                    prompt=prompt_summary.format(conversation_history),
                    role="summary",
                    ensure_label="answer",
                )
            except Exception:
                # Fall back to the raw exam results. Catching Exception
                # rather than everything lets KeyboardInterrupt and
                # SystemExit still stop the worker.
                temp_summary = generated_messages[-1]["content"]
            conversation_summary += temp_summary
            message["summary"] = conversation_summary
            # Half of the time use the verbatim exam dialogue instead.
            if random.uniform(0, 1) > 0.5:
                message["summary"] = auxiliary_dialog_str

        # Differential diagnosis: needs non-empty reasoning and an answer
        # that starts with a "- " bullet.
        for _ in range(3):
            long_diagnosis_reasoning_content, long_diagnosis_answer = (
                LLM_caller.query_model_with_reasoning(
                    model_str=args.expensive_llm,
                    prompt=prompt_long_diagnosis.format(message["summary"]),
                    role="long differential diagnosis",
                    ensure_label=None,
                )
            )
            if (
                long_diagnosis_answer.strip().startswith("- ")
                and long_diagnosis_reasoning_content
            ):
                break
        else:
            raise ValueError("generate_auxiliary_exam 生成鉴别诊断失败")
        message["long_diagnosis"] = [
            long_diagnosis_reasoning_content,
            long_diagnosis_answer,
        ]
        diagnosis_list = [
            d.strip() for d in long_diagnosis_answer.split("\n") if d.strip()
        ]
        if not diagnosis_list:
            raise ValueError("generate_auxiliary_exam 提取topk诊断失败")

        verified_diagnosis_str = "\n".join(diagnosis_list)
        message["verified_diagnosis"] = verified_diagnosis_str

        # Decide whether to keep examining. A decision to stop and diagnose
        # is only kept when the diagnosis is judged correct.
        valid_answer = False
        for _try_cnt in range(2):
            decision_reason, decision_answer = LLM_caller.query_model_with_reasoning(
                model_str=args.expensive_llm,
                prompt=prompt_exam_or_diagnosis.format(
                    message["summary"], verified_diagnosis_str
                ),
                role="decide exam or diagnosis",
                ensure_label=None,
            )
            wants_more_exams = decision_answer == "继续辅助检查"
            gives_diagnosis = (
                decision_answer.startswith("您的诊断结果为：")
                and "您的完整诊断如下：" in decision_answer
            )
            if not wants_more_exams and not gives_diagnosis:
                continue  # malformed answer, try again
            if wants_more_exams or compare_results(
                diagnosis_to_be_verified=decision_answer,
                correct_diagnosis=scenario.diagnosis_information(),
                LLM_caller=LLM_caller,
            ):
                valid_answer = True
                break

        if valid_answer:
            message["decision_exam_or_diagnosis"] = [decision_reason, decision_answer]
            continue_exam = decision_answer == "继续辅助检查"
        else:
            # No usable decision: carry on with another exam round.
            print("跳过判断是否继续检查")
            continue_exam = True

        if not continue_exam:
            # Final diagnosis ends the dialogue.
            message["content"] = decision_answer
            message["stage"] = "Diagnosis"
            generated_messages.append(message)
            return generated_messages

        # Request the next auxiliary exams.
        for _ in range(3):
            auxiliary_exam_reasoning_content, auxiliary_exam_answer = (
                LLM_caller.query_model_with_reasoning(
                    model_str=args.expensive_llm,
                    prompt=prompt_recommend_auxiliary_exams.format(
                        message["summary"], verified_diagnosis_str
                    ),
                    role="recommend auxiliary exams",
                    ensure_label=None,
                )
            )
            if (
                auxiliary_exam_answer.startswith("请求进行以下辅助检查：")
                and auxiliary_exam_reasoning_content
            ):
                break
        else:
            raise ValueError("生成下一步辅助检查 失败")
        message["content"] = auxiliary_exam_answer
        message["auxiliary_exam_reasoning"] = [
            auxiliary_exam_reasoning_content,
            auxiliary_exam_answer,
        ]
        generated_messages.append(message)

        # Patient-side reply: exam results made up from the full record.
        results_auxiliary_exam = LLM_caller.query_model_and_extract_label(
            model_str=args.cheap_llm,
            prompt=prompt_makeup_exam_results.format(
                scenario.full_record() + scenario.diagnosis_information(),
                auxiliary_exam_answer,
            ),
            role="makeup exam results",
            ensure_label="answer",
        )
        generated_messages.append(
            {
                "role": "user",
                "content": results_auxiliary_exam,
                "stage": "Auxiliary Exam",
            }
        )
        auxiliary_dialog_str += f"Doctor：{auxiliary_exam_answer}\nResults：{results_auxiliary_exam}\n\n"
    return generated_messages


def run_one_scenario(_scenario_id: int, scenario: Scenario_OSCE, full_ratio=0.25):
    """Generate one dialogue for a scenario and package the outcome.

    The inquiry stage is skipped: history and physical-exam findings are
    taken from the record (used verbatim with probability 0.8, otherwise
    summarized by ``args.cheap_llm``) and passed to generate_auxiliary_exam.

    Args:
        _scenario_id: Index of the scenario, echoed back in the result.
        scenario: The case to build a dialogue for.
        full_ratio: Currently unused.

    Returns:
        ``(results_dict, _scenario_id)``. ``results_dict["success"]`` is
        False when the scenario does not cover all three stages or any
        step raises; the error is printed rather than propagated.
    """
    LLM_caller = LLM_Caller_for_One_Thread(
        introduction_log=f"\n<hr>\n\n## {_scenario_id+1}\n",
        dict_from_sys_prompt_to_context_id=global_dict_from_sys_prompt_to_context_id,
    )
    try:
        if not scenario.cover_3_stages():
            raise ValueError("病历质量差，未覆盖3个阶段")

        record = (
            "病史信息：\n"
            + scenario.patient_information()
            + "\n\n"
            + "体格检查结果：\n"
            + scenario.physical_exam_information()
        )
        # Mostly feed the raw record; otherwise an LLM-written summary.
        keep_raw_record = random.uniform(0, 1) < 0.8
        if keep_raw_record:
            summary_inquiry_physical = record
        else:
            summary_inquiry_physical = LLM_caller.query_model_and_extract_label(
                model_str=args.cheap_llm,
                prompt=prompt_summary_record.format(record),
                role="summary from record",
                ensure_label="answer",
            )
        dialogue_messages = generate_auxiliary_exam(
            scenario,
            [],
            summary_inquiry_physical=summary_inquiry_physical,
            LLM_caller=LLM_caller,
        )
        outcome = {
            "success": True,
            "scenario_id": _scenario_id,
            "correct_diagnosis": scenario.diagnosis_information(),
            "conversation_str": transform_messages_to_conversation_str(
                dialogue_messages
            ),
            "generated_messages": dialogue_messages,
            "OSCE_Examination": scenario.scenario_dict["OSCE_Examination"],
        }
    except Exception as e:
        print("Error:", _scenario_id, e, flush=True)
        outcome = {
            "success": False,
            "scenario_id": _scenario_id,
            "conversation_str": "",
        }
    return outcome, _scenario_id


def main():
    """Generate dialogues for all scenarios in parallel and save them.

    Optionally resumes from ``args.continue_file_path``: successful results
    found there are kept, and each scenario is only run for the repetitions
    still missing to reach ``args.repeat_num``. A checkpoint is written to
    ``temp/`` every time the result count reaches a multiple of 100, and the
    final list goes to ``detailed_log/``.

    Relies on the module-level ``train_id``, which is assigned by the script
    entry point before this function is called.
    """
    scenario_loader = Scenario_OSCE_Loader(args.dataset_path)
    results_list = []
    # Resume after an interruption.
    if args.continue_file_path:
        results_list = read_json(args.continue_file_path)["results_list"]
        print(f"Continue from {args.continue_file_path}", len(results_list))

    # Number of successful runs already available per scenario.
    scenario_cnt = {}
    for res in results_list:
        if res.get("success", False):
            scenario_cnt[res["scenario_id"]] = (
                scenario_cnt.get(res["scenario_id"], 0) + 1
            )

    if args.num_scenarios is None:
        args.num_scenarios = scenario_loader.num_scenarios
    total_num_scenarios = min(args.num_scenarios, scenario_loader.num_scenarios)

    # Create the process pool and submit the tasks.
    with ProcessPoolExecutor(max_workers=args.parallel_thread_num) as executor:
        future_to_scenario = {}
        for _scenario_id in range(total_num_scenarios):
            # Each scenario is repeated until it has repeat_num successes
            # on record (counting those from the resumed file).
            for repeat_times in range(
                scenario_cnt.get(_scenario_id, 0), args.repeat_num
            ):
                future = executor.submit(
                    run_one_scenario,
                    _scenario_id,
                    scenario_loader.get_scenario(_scenario_id),
                )
                future_to_scenario[future] = _scenario_id * 100 + repeat_times

        # Collect results as the workers finish.
        for future in as_completed(future_to_scenario):
            info_dict, _scenario_id = future.result()
            if info_dict.get("success", False):
                print(f"Scene {_scenario_id+1}: done", flush=True)
                results_list.append(info_dict)
                if len(results_list) % 100 == 0:
                    write_json(
                        {"settings": args.to_dict(), "results_list": results_list},
                        f"temp/{train_id}_{len(results_list)}.json",
                    )
            else:
                print(f"Scene {_scenario_id+1}: skipped")

    results_list = sorted(results_list, key=lambda x: x["scenario_id"])

    for item in results_list:
        print(f"### {item['scenario_id']+1}")
        # Entries resumed from a final output file no longer carry
        # "conversation_str" (it is stripped below before saving), so the
        # key must be treated as optional; otherwise a KeyError here would
        # discard the whole run right before it is written.
        print(item.pop("conversation_str", ""))
        # Dialogue length, stored for later averaging.
        item["conversation_length"] = len(item.get("generated_messages", []))

    write_json(
        {"settings": args.to_dict(), "results_list": results_list},
        f"detailed_log/{train_id}.json",
    )


def prepare_cache_context_id():
    """
    Register context ids for system prompts, for the sake of prompt caching.

    For each of ``args.cheap_llm`` and ``args.expensive_llm`` that appears in
    ``cached_LLM_id_list``, every prompt in the (currently empty) list of
    system prompts to cache is mapped to the id returned by get_context_id
    in ``global_dict_from_sys_prompt_to_context_id``.
    """
    for model_id in (args.cheap_llm, args.expensive_llm):
        if model_id in cached_LLM_id_list:
            # No system prompts are registered for caching at the moment.
            system_prompts_to_cache = []
            for system_prompt in system_prompts_to_cache:
                context_id = get_context_id(
                    sys_prompt=system_prompt,
                    model_id=model_id,
                )
                global_dict_from_sys_prompt_to_context_id[system_prompt] = context_id


if __name__ == "__main__":
    # Script entry point: set up prompt caching, derive the run identifier
    # (read by main() as a module-level global), then run the generation.
    prepare_cache_context_id()
    current_time = datetime.now().strftime("%Y%m%d-%H%M")

    # Dataset id: file stems joined by "-", with these fragments removed
    # in this order.
    dataset_id = "-".join([Path(p).stem for p in args.dataset_path])
    for _fragment in ("-fill_physical", "filter-", "OSCE_"):
        dataset_id = dataset_id.replace(_fragment, "")

    # Only the last path component of the doctor model name is kept.
    _doctor_name = str(args.expensive_llm).split("/")[-1]
    train_id = f"{current_time}-{dataset_id}--doctor={_doctor_name}"

    print(train_id)
    main()
