import numpy as np
from utils.myutils import AIClient, errorMessage
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

# Closing instructions appended verbatim to every rubric below: they tell the
# judge to score the session as a whole and fix the reply layout
# (an <ANSWER> block holding "[score/5] # explanation").
_SCORE_OUTPUT_INSTRUCTIONS = (
    "Evaluate the series of questions as a whole. Only consider the doctor's questions. The patient's responses are context only. "
    "Provide a single numeric score [1-5] wrapped in square brackets, and a brief explanation.\n\n"
    "Format:\n"
    "<ANSWER>\n"
    "[score/5] # Explanation for the Score\n"
    "</ANSWER>"
)

# Rubric for role adherence: does the doctor stay in character and only ask
# questions.
_ADHERENCE_RUBRIC = (
    "You are evaluating the ADHERENCE of a doctor’s questioning sequence in a doctor-patient interaction.\n\n"
    "Evaluation target:\n"
    "- Focus ONLY on the doctor’s questions (the patient’s answers are context to help you judge, not the evaluation subject).\n\n"
    "Adherence means:\n"
    "- The doctor always acts in the role of a professional doctor.\n"
    "- No breaking character (e.g., 'I am an AI', disclaimers, meta-statements).\n"
    "- No mechanical listing of multiple unrelated questions in one turn.\n"
    "- The doctor should ONLY ask questions to elicit information.\n"
    "- The doctor must NOT provide diagnoses, medical reasoning, treatment advices, or conclusions.\n"
    "- All questions must be natural, professional, relevant, and role-consistent.\n"
    "- Whether questions are repeated is NOT part of adherence (that is evaluated separately under coherence).\n\n"
    "Scoring guidance (per session):\n"
    "- [5/5]: Fully role-consistent, professional doctor style maintained throughout.\n"
    "- [4/5]: Minor issues (slightly verbose, or occasional awkward phrasing).\n"
    "- [3/5]: Some issues (AI-like wording, mechanical phrasing, or occasional irrelevant questions).\n"
    "- [2/5]: Frequent breaking of role, frequent mechanical listing, or multiple irrelevant questions.\n"
    "- [1/5]: Clear violation: AI self-disclosure, giving diagnoses/medical reasoning, or repeated meta-behaviors.\n\n"
)

# Rubric for coherence: logical flow of the question sequence, including
# repetition of questions the patient has already answered.
_COHERENCE_RUBRIC = (
    "You are evaluating the COHERENCE of a doctor’s questioning sequence in a doctor-patient dialogue.\n\n"
    "Evaluation target:\n"
    "- Focus on the doctor’s questions as a sequence.\n"
    "- Patient answers are used only as context to judge whether the doctor’s questions are coherent, not as the evaluation subject.\n\n"
    "Coherence means:\n"
    "- Questions should follow logically across the sequence.\n"
    "- No contradictions with what the patient has already answered.\n"
    "- No repeated questions (whether exact or paraphrased) that seek information the patient has already clearly provided.\n"
    "- Smooth transitions, natural flow, consistent with patient’s context.\n\n"
    "Scoring guidance (per session):\n"
    "- [5/5]: Questions flow naturally, no unnecessary repetition, smooth logical progression.\n"
    "- [4/5]: Mostly coherent, with minor redundancy or slightly awkward flow.\n"
    "- [3/5]: Some issues (e.g., noticeable repetition, weak logical links between questions).\n"
    "- [2/5]: Frequent repetition or disjointed question flow.\n"
    "- [1/5]: Severe incoherence: many repeated or contradictory questions, very poor flow.\n\n"
)

# Judge configuration keyed by evaluation dimension. Each entry carries the
# complete system prompt (rubric followed by the shared output instructions).
EVAL_SPECS = {
    "adherence": {"system_prompt": _ADHERENCE_RUBRIC + _SCORE_OUTPUT_INSTRUCTIONS},
    "coherence": {"system_prompt": _COHERENCE_RUBRIC + _SCORE_OUTPUT_INSTRUCTIONS},
}

class DialogManageEvaluator:
    """Score doctor-patient sessions with an LLM judge.

    Each logged session is judged on two dimensions, ``"adherence"`` and
    ``"coherence"``, using the system prompts stored in ``EVAL_SPECS``.
    The judge's 1-5 score is divided by 5 for reporting; a dimension that
    could not be scored is reported as ``0.0``.
    """

    def __init__(self, config, save_folder):
        """Create the judge client.

        Args:
            config: Mapping with the judge model name under
                ``"dm_judge_model"`` or, when that key is absent, under
                ``"judge_model"``.
            save_folder: Folder in which ``dm_errors.log`` is written.

        Raises:
            KeyError: If neither model key is present in ``config``.
        """
        self.config = config
        # A judge configured specifically for this evaluator wins over the
        # general one. "judge_model" is only looked up when actually needed.
        self.model = (
            config["dm_judge_model"]
            if "dm_judge_model" in config
            else config["judge_model"]
        )
        self.client = AIClient(model=self.model)
        self.save_folder = save_folder

    def evaluate(self, log_path, max_workers=8):
        """Judge every session in a JSON-lines log and aggregate the scores.

        Args:
            log_path: Path to a UTF-8 file with one JSON object per line.
                Each object is read for ``"interactions"`` and ``"case_id"``.
            max_workers: Number of threads used to score sessions
                concurrently.

        Returns:
            ``{"avg_scores": {"adherence": float, "coherence": float},
            "details": [...]}``. ``details`` holds one result dict per
            successfully processed session, in input order. Averages are
            rounded to three decimals and are ``0.0`` when nothing was
            scored.
        """
        entries = self._load_entries(log_path)

        results_by_index = {}
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Remember each future's input position so the report can be put
            # back in file order; as_completed yields in completion order.
            pending = {
                executor.submit(self._score_entry, entry, idx): idx
                for idx, entry in enumerate(entries)
            }
            for fut in tqdm(as_completed(pending), total=len(pending), desc="Dialog Manage Evaluation"):
                try:
                    results_by_index[pending[fut]] = fut.result()
                except Exception as e:
                    # Best effort: one broken session must not abort the run.
                    # It is reported and left out of details and averages.
                    print(f"⚠️ Entry processing failed: {e}")

        details = [results_by_index[idx] for idx in sorted(results_by_index)]
        adherence_list = [row["adherence_score"] for row in details]
        coherence_list = [row["coherence_score"] for row in details]

        return {
            "avg_scores": {
                "adherence": round(float(np.mean(adherence_list)), 3) if adherence_list else 0.0,
                "coherence": round(float(np.mean(coherence_list)), 3) if coherence_list else 0.0
            },
            "details": details
        }

    @staticmethod
    def _load_entries(log_path):
        """Read a JSON-lines file, skipping blank or whitespace-only lines."""
        with open(log_path, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    def _score_entry(self, entry, idx):
        """Judge one session on both dimensions and build its result row.

        Args:
            entry: Parsed log object for the session.
            idx: Position of the session in the log; used to build the
                fallback case id ``row_<idx>``.

        Returns:
            Dict with ``case_id``, ``adherence_score``, ``adherence_details``,
            ``coherence_score`` and ``coherence_explanation``.
        """
        # A missing key and an explicit null both mean "no interactions".
        interactions = entry.get("interactions") or []
        adherence = self._eval_dimension(interactions, "adherence")
        coherence = self._eval_dimension(interactions, "coherence")
        return {
            "case_id": entry.get("case_id", f"row_{idx}"),
            "adherence_score": adherence.get("score", 0) / 5 if adherence else 0.0,
            "adherence_details": adherence,
            "coherence_score": coherence.get("score", 0) / 5 if coherence else 0.0,
            "coherence_explanation": coherence.get("explanation", "") if coherence else ""
        }

    @staticmethod
    def _format_dialogue(interactions):
        """Render the session as numbered ``Qn:``/``An:`` lines for the judge.

        Each item of ``interactions`` is unpacked as a (question, answer)
        pair.
        """
        pairs = "\n".join(
            f"Q{n}: {q}\nA{n}: {a}"
            for n, (q, a) in enumerate(interactions, start=1)
        )
        return (
            f"Here are {len(interactions)} Q&A pairs from a doctor-patient interaction session.\n"
            + pairs
            + "\nPlease provide ONE score for the whole session."
        )

    @staticmethod
    def _parse_judgement(content):
        """Extract ``{"score", "explanation"}`` from the judge's reply text.

        The text inside ``<ANSWER>...</ANSWER>`` is used when that block is
        present, otherwise the whole reply. The first ``[n/5]`` marker is
        taken as the score, and any text after a following ``#`` as the
        explanation.

        Returns:
            The parsed dict, or ``None`` when there is no marker or the score
            is outside the rubric's 1-5 range.

        Raises:
            TypeError: If ``content`` is not a string.
        """
        import re

        tagged = re.search(r"<ANSWER>(.*?)</ANSWER>", content, re.DOTALL | re.IGNORECASE)
        answer_block = tagged.group(1) if tagged else content

        # Whitespace inside the brackets is tolerated; the range check below
        # rejects values such as [0/5] or [7/5] that the rubric does not allow.
        found = re.search(r"\[\s*(\d+)\s*/\s*5\s*\](?:\s*#\s*(.*))?", answer_block, re.DOTALL)
        if not found:
            return None

        score = int(found.group(1))
        if not 1 <= score <= 5:
            return None
        return {"score": score, "explanation": (found.group(2) or "").strip()}

    def _eval_dimension(self, interactions: list[tuple[str, str]], dim: str, max_retries=3):
        """Ask the judge for one session-level score on a single dimension.

        Args:
            interactions: The session's (question, answer) pairs.
            dim: Key into ``EVAL_SPECS`` (``"adherence"`` or ``"coherence"``).
            max_retries: Maximum number of attempts when the request or the
                handling of its reply raises.

        Returns:
            ``{"score": int, "explanation": str}``. The score is 1-5 on
            success. On any failure the result is
            ``{"score": 0, "explanation": ""}``, always a dict, so callers
            can treat every outcome uniformly.

        Raises:
            ValueError: If ``dim`` is not a key of ``EVAL_SPECS``.
        """
        spec = EVAL_SPECS.get(dim)
        if not spec:
            raise ValueError(f"Unknown dimension: {dim}")

        messages = [
            {"role": "system", "content": spec["system_prompt"]},
            {"role": "user", "content": self._format_dialogue(interactions)}
        ]
        error_log = f"{self.save_folder}/dm_errors.log"

        for attempt in range(max_retries):
            try:
                response = self.client.get_response(messages)
                # The reply text is response["content"] when the client
                # returns a dict, otherwise the response itself.
                content = response["content"] if isinstance(response, dict) else response
                judgement = self._parse_judgement(content)
            except Exception as e:
                errorMessage(f"{dim.capitalize()} scoring failed (attempt {attempt+1}): {e}", error_log)
                continue

            if judgement is not None:
                return judgement

            # A reply without a usable score is not retried; it is logged and
            # counted as a zero.
            errorMessage(f"⚠️ [{dim}] Missing score, got: {content}", error_log)
            return {"score": 0, "explanation": ""}

        errorMessage(f"[ERROR] {dim} evaluation failed after {max_retries} attempts.", error_log)
        return {"score": 0, "explanation": ""}