import collections

from .base import Doctor
from utils.myutils import doctorMessage

class DoctorRerank(Doctor):
    """Doctor agent that samples several candidate questions and reranks them.

    Each call to :meth:`ask` draws ``num_samples`` candidate replies from
    ``self.client``, adds an explicit "end consultation" option, and then asks
    the same client to pick the best candidate given the dialogue so far.
    """

    def __init__(self, config, task, choices, finish_len=10, num_samples=5):
        """Set up dialogue state.

        Args:
            config: Configuration object exposing ``get(section, option)``.
            task: Forwarded unchanged to the ``Doctor`` base class.
            choices: Iterable of diagnosis option strings shown in the prompts.
            finish_len: Number of question/answer rounds after which ``ask``
                stops returning new questions.
            num_samples: How many candidate questions to sample per turn.
        """
        super().__init__(config, task, choices)
        self.finish_len = finish_len
        # Alternating user/assistant chat messages of the consultation so far.
        self.history = []
        # NOTE(review): read from config but not referenced anywhere in this
        # class; the round cap enforced in ask() uses finish_len instead.
        self.max_rounds = int(self.config.get('doctor', 'max_questions'))
        # Per-turn record of the sampled candidates.
        self.additional_info = {"samples": []}
        self.num_samples = num_samples

    @staticmethod
    def _match_candidate(reply, candidates):
        """Map the reranker's reply onto one of ``candidates``.

        An exact match wins. Otherwise the longest non-empty candidate whose
        text appears (case-insensitively) inside the reply is used, which
        covers replies such as ``Candidate 2: <question>`` or a quoted
        question. Returns ``None`` when nothing matches.
        """
        if reply in candidates:
            return reply
        folded_reply = reply.casefold()
        contained = [c for c in candidates if c and c.casefold() in folded_reply]
        if not contained:
            return None
        # Longest hit first so a candidate that is a prefix/substring of
        # another candidate does not shadow the more specific one.
        return max(contained, key=len)

    def ask(self, answer):
        """Record the patient's answer and produce the next question.

        Args:
            answer: The patient's reply to the previous question, or ``None``
                on the first turn.

        Returns:
            The selected question string; ``None`` when the selected action is
            to end the consultation (or is blank); or the literal string
            ``"Exceed the maximum number of rounds"`` once the history holds
            at least ``finish_len * 2`` messages.
        """
        user_turn = answer if answer is not None else "The consultation begins."
        self.history.append({"role": "user", "content": user_turn})

        joined_choices = ', '.join(self.choices)

        # NOTE(review): there is no space between "diagnosis." and "Possible"
        # in the assembled prompt; left untouched to keep the prompt stable.
        system_prompt = (
            "You are a medical professional in medical inquiry. "
            "You ask questions to gather information for diagnosis."
            "Possible questions include demographics, symptoms, medical history, family "
            "history, physical exam findings, lab results, and lifestyle. "
            "If you think the collected information is sufficient for a diagnosis "
            f"among the following choices: {joined_choices}, "
            "respond with \"end consultation\" to stop the inquiry. "
            "If not, ask **ONE** question a turn, balancing information acquisition, "
            "dialogue quality, inquiry efficiency, and patient experience. "
            "Never provide diagnosis or treatment suggestions. "
        )

        messages = [{"role": "system", "content": system_prompt}] + self.history

        # Sample multiple candidate questions from the same prompt.
        doctorMessage("Sampling multiple questions...")
        candidate_questions = [
            self.client.get_response(messages=messages).strip()
            for _ in range(self.num_samples)
        ]

        # Always give the reranker the option to end the consultation.
        if "end consultation" not in candidate_questions:
            candidate_questions.append("end consultation")

        # history has an odd length here (the user turn was just appended),
        # so this yields 1 for the first call, 2 for the second, and so on.
        self.additional_info["samples"].append(
            {"round": len(self.history) // 2 + 1, "candidates": candidate_questions}
        )

        # Build the candidate listing outside the f-string: a backslash inside
        # an f-string replacement field is a SyntaxError before Python 3.12,
        # and on 3.12+ '\\n' there is a literal backslash + "n", not a newline.
        candidate_lines = "\n".join(
            f"- Candidate {idx}: {text}"
            for idx, text in enumerate(candidate_questions, start=1)
        )

        # Reranking
        rerank_prompt = (
            "Given the following medical inquiry history:\n"
            f"History: {self.history}\n"
            "And a list of candidate actions for the next turn:\n"
            f"{candidate_lines}\n"
            "Please select the best next action. You may either ask a new question or end the inquiry. "
            "If you decide to continue the dialogue, choose the question that is most logical, "
            "consistent with the history, and most likely to efficiently gather needed information "
            f"for a diagnosis among the following choices: {joined_choices}, "
            "while also being clear and empathetic. "
            "Respond with the exact text of the chosen action, without any other explanation."
        )
        rerank_messages = [{"role": "user", "content": rerank_prompt}]

        reply = self.client.get_response(messages=rerank_messages).strip()
        best_question = self._match_candidate(reply, candidate_questions)

        # Fallback: the reply could not be tied to any candidate, so take the
        # most frequently sampled one (ties resolve to the earliest sample).
        if best_question is None:
            counts = collections.Counter(candidate_questions)
            best_question = counts.most_common(1)[0][0]
            doctorMessage(f"Reranking failed to select a candidate. Falling back to most common: {best_question}")

        doctorMessage(f"Selected Question: {best_question}")

        if 'end consultation' in best_question.lower() or best_question.strip() == '':
            return None
        # The cap is checked only after sampling, so an "end consultation"
        # selection on the final round still returns None rather than this.
        if len(self.history) >= self.finish_len * 2:
            return "Exceed the maximum number of rounds"

        self.history.append({"role": "assistant", "content": best_question})
        return best_question