from .base import Doctor
from utils.myutils import doctorMessage

class DoctorCot(Doctor):
    """Doctor agent that reasons step by step before asking each question.

    Each turn, the model is asked to write its reasoning, then a
    ``### Question`` delimiter, then either one question or "end consultation".
    The reasoning is logged and stored in ``additional_info["reasoning"]``.
    """

    # Marker the model is told to emit between its reasoning and its question.
    _QUESTION_DELIMITER = "### Question"

    def __init__(self, config, task, choices, finish_len=10):
        """Initialise the agent.

        Args:
            config: Passed to ``Doctor.__init__``; must support
                ``get('doctor', 'max_questions')``.
            task: Passed through to ``Doctor.__init__``.
            choices: Passed through to ``Doctor.__init__``; also joined into the
                system prompt as the candidate diagnoses.
            finish_len: Cap used by ``ask``: once the history reaches
                ``finish_len * 2`` messages, ``ask`` stops asking.
        """
        super().__init__(config, task, choices)
        self.finish_len = finish_len
        self.history = []
        # NOTE(review): read from config but not referenced in this class; the
        # round cap in ask() uses finish_len. Confirm which limit is intended.
        self.max_rounds = int(self.config.get('doctor', 'max_questions'))
        self.additional_info = {"reasoning": []}

    def _system_prompt(self):
        """Build the system prompt (rebuilt per call so it reflects current choices)."""
        return (
            "You are a medical professional in medical inquiry. "
            # Trailing spaces matter: adjacent literals are concatenated verbatim.
            "You ask questions to gather information for diagnosis. "
            "Possible questions include demographics, symptoms, medical history, family "
            "history, physical exam findings, lab results, and lifestyle. "

            # CoT prompt
            "Think step by step. You should first write your reasoning about "
            "whether you have enough information, and if not, what information "
            "is still needed to make a diagnosis. "
            "Based on this reasoning, you will either ask one question "
            "or decide to end the consultation. After your reasoning, "
            "output your question (or the decision to end) after a "
            f"'{self._QUESTION_DELIMITER}' delimiter. "

            "If you think the collected information is sufficient for a diagnosis "
            f"among the following choices: {', '.join(self.choices)}, "
            "respond with \"end consultation\" to stop the inquiry. "
            "If not, ask **ONE** question a turn, balancing information acquisition, "
            "dialogue quality, inquiry efficiency, and patient experience. "
            "Never provide diagnosis or treatment suggestions. "
        )

    def ask(self, answer):
        """Record the patient's answer and return the doctor's next question.

        Args:
            answer: The patient's reply to the previous question, or None on the
                first turn (the opening message "The consultation begins." is
                recorded instead).

        Returns:
            * ``None`` if the model's question contains "end consultation"
              (case-insensitive) or is empty;
            * the string ``"Exceed the maximum number of rounds"`` if the history
              has reached ``finish_len * 2`` messages;
            * otherwise the new question, which is also appended to the history.
        """
        opening = answer if answer is not None else "The consultation begins."
        self.history.append({"role": "user", "content": opening})

        messages = [{"role": "system", "content": self._system_prompt()}] + self.history
        # NOTE(review): self.client is presumably set up by Doctor.__init__.
        response = self.client.get_response(messages=messages)

        # Reasoning = text before the first delimiter; question = text after the
        # last one. If the delimiter is missing, both fall back to the whole
        # response (identical to str.split(...)[0] / [-1]).
        reasoning = response.partition(self._QUESTION_DELIMITER)[0].strip()
        new_question = response.rpartition(self._QUESTION_DELIMITER)[2].strip()

        doctorMessage(f"Reasoning: {reasoning}")
        doctorMessage(f"New Question: {new_question}")

        # History is user-first and alternating, so this is the 1-based turn index.
        turn = len(self.history) // 2 + 1
        self.additional_info["reasoning"].append({"round": turn, "reasoning": reasoning})

        # An explicit stop (or empty output) takes precedence over the round cap.
        if not new_question or 'end consultation' in new_question.lower():
            return None
        if len(self.history) >= self.finish_len * 2:
            return "Exceed the maximum number of rounds"

        self.history.append({"role": "assistant", "content": new_question})
        return new_question