from utils.myutils import AIClient
import json
import concurrent.futures
from tqdm import tqdm
from utils.myutils import errorMessage

class InfoQualityEvaluator:
    """Score consultation logs with an LLM judge.

    Two metrics are computed per case:

    * ``coverage``  -- fraction of atomic information units that the
      patient's answers cover;
    * ``relevance`` -- fraction of doctor questions that target at least one
      atomic information unit.

    Judge failures (empty reply, missing ``<ANSWER>`` block, wrong number of
    verdicts, exceptions) are retried and reported to
    ``<save_folder>/iq_errors.log`` through ``errorMessage``.
    """

    # Unit lists longer than this are split in half and judged recursively.
    _MAX_UNITS_PER_CALL = 10

    # A patient answer containing any of these substrings (case-insensitive)
    # is dropped before coverage is judged.
    _UNCERTAIN_MARKERS = ("not sure", "unknown", "don't know")

    # NOTE(review): rules 1-2 talk about units being "asked about" by the
    # doctor's questions although the judge is given patient answers; the
    # wording is kept as-is so scores stay comparable -- confirm intent.
    _COVERAGE_SYSTEM_PROMPT = (
        "You are a careful medical evaluator. Your task is to check whether a series of patient answers "
        "cover or explicitly ask about each atomic information unit.\n\n"
        "Rules:\n"
        "1. Mark as [yes] if the unit is clearly being asked about, even if phrasing differs.\n"
        "2. Mark as [no] only if the doctor’s questions never address the unit.\n"
        "3. Do not be overly strict with wording; focus on semantic coverage.\n"
        "4. If the patient expresses uncertainty (e.g., 'not sure', 'unknown'), treat it as not covering the unit.\n"
        "5. Only [yes] or [no] is allowed, no other options.\n\n"
        "Output ONLY in this format inside <ANSWER> tags:\n"
        "<ANSWER>\nU1: [yes/no]\nU2: [yes/no]\n...</ANSWER>\n\n"
        "Do not add any explanation, comments, or extra text after [yes/no]."
    )

    _RELEVANCE_SYSTEM_PROMPT = (
        "You are a careful medical evaluator. Your task is to check whether each doctor’s question "
        "is clearly aimed at retrieving at least one atomic information item.\n"
        "Rules:\n"
        "1. Mark [yes] if the question is meaningfully related to any item, even if implicit.\n"
        "2. Mark [no] only if the question is irrelevant or unrelated.\n"
        "3. Focus on intent, not superficial keyword overlap.\n"
        "4. Only [yes] or [no] is allowed, no other options.\n\n"
        "Output ONLY in this format inside <ANSWER> tags:\n"
        "<ANSWER>\nQ1: [yes/no]\nQ2: [yes/no]\n...</ANSWER>\n"
        "Do not add any explanation, comments, or extra text after [yes/no]."
    )

    def __init__(self, config, save_folder):
        """Create the judge client.

        Args:
            config: Mapping holding ``"iq_judge_model"`` or, as a fallback,
                ``"judge_model"``.
            save_folder: Folder prefix used for the ``iq_errors.log`` path.
        """
        self.config = config
        # Keep the conditional form: ``judge_model`` is only required when
        # ``iq_judge_model`` is absent.
        self.model = (
            config["iq_judge_model"]
            if "iq_judge_model" in config
            else config["judge_model"]
        )
        self.client_judge = AIClient(model=self.model)
        self.save_folder = save_folder

    def _log_error(self, message):
        """Pass *message* to ``errorMessage`` for the error log; never raise."""
        try:
            errorMessage(message, self.save_folder + "/iq_errors.log")
        except Exception:
            # Best-effort: a broken log sink must not abort the evaluation.
            pass

    def _ask_judge(self, messages, expected, max_retries, context=None):
        """Query the judge until it yields exactly *expected* verdicts.

        Args:
            messages: Chat messages passed to ``client_judge.get_response``.
            expected: Number of ``[yes]``/``[no]`` verdicts required.
            max_retries: Extra attempts allowed after the first one.
            context: Optional JSON-serialisable object logged next to the
                answer block when the verdict count does not match.

        Returns:
            ``(verdicts, attempt)`` where ``verdicts`` is a list of bools and
            ``attempt`` the zero-based index of the successful try, or
            ``(None, max_retries + 1)`` when every attempt failed.
        """
        import re

        for attempt in range(max_retries + 1):
            try:
                response = self.client_judge.get_response(messages)
                if not response:
                    self._log_error("⚠️ Empty response from judge model.")
                    continue

                match = re.search(
                    r"<\s*ANSWER\s*>(.*?)<\s*/\s*ANSWER\s*>",
                    response,
                    re.DOTALL | re.IGNORECASE,
                )
                if not match:
                    self._log_error("⚠️ No <ANSWER> block found.")
                    continue

                answer_block = match.group(1).strip()
                verdicts = re.findall(r"\[(yes|no)\]", answer_block, flags=re.IGNORECASE)
                if len(verdicts) != expected:
                    self._log_error(
                        f"⚠️ Mismatch in number of judgments: expected {expected}, got {len(verdicts)}"
                    )
                    prefix = "" if context is None else json.dumps(context, indent=2) + "\n"
                    self._log_error(prefix + answer_block)
                    continue

                return [v.lower() == "yes" for v in verdicts], attempt
            except Exception as exc:
                # Any failure (client error, malformed reply) costs one
                # attempt; record it instead of swallowing it silently.
                self._log_error(f"⚠️ Judge attempt {attempt + 1} failed: {exc!r}")

        return None, max_retries + 1

    def evaluate_case(self, entry):
        """Compute coverage and relevance for one log entry.

        Args:
            entry: Mapping with ``"case_id"``, ``"atom_information"`` (list of
                units) and ``"interactions"`` (sequence of
                ``(question, answer)`` pairs). Missing or null lists are
                treated as empty.

        Returns:
            Dict with ``case_id``, ``coverage``, ``coverage_detail``,
            ``relevance`` and ``relevance_detail``.
        """
        atom_info_list = entry.get("atom_information") or []
        interactions = entry.get("interactions") or []
        questions = [question for question, _ in interactions]

        informative_answers = []
        for _, answer in interactions:
            lowered = answer.lower()
            if not any(marker in lowered for marker in self._UNCERTAIN_MARKERS):
                informative_answers.append(answer)
        patient_answers = " ".join(informative_answers)

        coverage_score = self.calc_coverage(patient_answers, atom_info_list)
        relevance_score = self.calc_relevance(questions, atom_info_list)

        return {
            "case_id": entry.get("case_id"),
            "coverage": coverage_score["coverage"],
            "coverage_detail": {
                "unit_num": len(atom_info_list),
                "covered_num": sum(1 for j in coverage_score["judgments"] if j.get("covered")),
                "judgments": coverage_score["judgments"],
                "retry_count": coverage_score["retry_count"],
            },
            "relevance": relevance_score["relevance"],
            "relevance_detail": {
                "q_num": len(questions),
                "relevant_num": sum(1 for j in relevance_score["judgments"] if j.get("relevant")),
                "judgments": relevance_score["judgments"],
                "retry_count": relevance_score["retry_count"],
            },
        }

    def evaluate(self, log_path):
        """Evaluate every case of a JSON-lines log file concurrently.

        Args:
            log_path: Path of a UTF-8 file with one JSON object per
                non-blank line.

        Returns:
            Dict with ``avg_scores`` (coverage/relevance rounded to three
            decimals, 0.0 when there are no cases) and ``details`` (one
            result per case, in input order).

        Raises:
            Whatever ``evaluate_case`` raised for a case, re-raised by
            ``Future.result()``.
        """
        with open(log_path, "r", encoding="utf-8") as f:
            log_data = [json.loads(line) for line in f if line.strip()]

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(self.evaluate_case, entry) for entry in log_data]
            # Drive the progress bar by completion ...
            for _ in tqdm(
                concurrent.futures.as_completed(futures),
                total=len(futures),
                desc="Info Quality Evaluation",
            ):
                pass
            # ... but collect in submission order so the output is reproducible.
            details = [future.result() for future in futures]

        case_count = len(details)
        avg_coverage = sum(d["coverage"] for d in details) / case_count if case_count else 0.0
        avg_relevance = sum(d["relevance"] for d in details) / case_count if case_count else 0.0

        return {
            "avg_scores": {
                "coverage": round(avg_coverage, 3),
                "relevance": round(avg_relevance, 3),
            },
            "details": details,
        }

    def calc_coverage(self, patient_answers: str, atom_info_list: list[str], max_retries=3) -> dict:
        """Judge which atomic units are covered by the patient's answers.

        Args:
            patient_answers: Concatenated patient answers.
            atom_info_list: Atomic information units to check.
            max_retries: Extra judge attempts per request after the first.

        Returns:
            Dict with ``coverage`` (covered units / all units),
            ``judgments`` (``{"unit", "covered"}`` per judged unit) and
            ``retry_count``. When the judge never produces a usable reply
            the result is coverage 0.0, no judgments and
            ``retry_count == max_retries + 1``. An empty unit list returns
            coverage 0.0 without calling the judge.
        """
        if not atom_info_list:
            # Nothing to judge; the judge could never return a valid answer.
            return {"coverage": 0.0, "judgments": [], "retry_count": 0}

        if len(atom_info_list) > self._MAX_UNITS_PER_CALL:
            mid = len(atom_info_list) // 2
            first = self.calc_coverage(patient_answers, atom_info_list[:mid], max_retries)
            second = self.calc_coverage(patient_answers, atom_info_list[mid:], max_retries)
            merged_judgments = first["judgments"] + second["judgments"]
            covered = sum(1 for j in merged_judgments if j.get("covered"))
            return {
                # Divide by every unit: a half whose judging failed counts as
                # uncovered, matching the 0.0 a failed request yields.
                "coverage": covered / len(atom_info_list),
                "judgments": merged_judgments,
                "retry_count": first["retry_count"] + second["retry_count"],
            }

        unit_lines = "\n".join(
            f"U{index}:{item}" for index, item in enumerate(atom_info_list, start=1)
        )
        user_prompt = (
            f"Patient’s answer:\n{patient_answers}\n\n"
            "Atomic Information Units:\n"
            f"{unit_lines}\n\n"
            "Check each unit for coverage and respond in the specified format. "
            f"Include {len(atom_info_list)} lines in the <ANSWER> block."
        )
        messages = [
            {"role": "system", "content": self._COVERAGE_SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]

        verdicts, attempts = self._ask_judge(messages, len(atom_info_list), max_retries)
        if verdicts is None:
            return {"coverage": 0.0, "judgments": [], "retry_count": attempts}

        judgments = [
            {"unit": unit, "covered": covered}
            for unit, covered in zip(atom_info_list, verdicts)
        ]
        return {
            "coverage": sum(verdicts) / len(verdicts),
            "judgments": judgments,
            "retry_count": attempts,
        }

    def calc_relevance(self, questions: list[str], atom_info_list: list[str], max_retries=3) -> dict:
        """Judge which doctor questions target at least one atomic unit.

        Args:
            questions: Doctor questions, in asking order.
            atom_info_list: Atomic information items.
            max_retries: Extra judge attempts per request after the first.

        Returns:
            Dict with ``relevance`` (relevant questions / all questions),
            ``judgments`` (``{"question", "relevant"}`` per question) and
            ``retry_count``. When the judge never produces a usable reply
            the result is relevance 0.0, no judgments and
            ``retry_count == max_retries + 1``. An empty question list
            returns relevance 0.0 without calling the judge.
        """
        if not questions:
            # Nothing to judge; the judge could never return a valid answer.
            return {"relevance": 0.0, "judgments": [], "retry_count": 0}

        if len(atom_info_list) > self._MAX_UNITS_PER_CALL:
            mid = len(atom_info_list) // 2
            parts = [
                self.calc_relevance(questions, atom_info_list[:mid], max_retries),
                self.calc_relevance(questions, atom_info_list[mid:], max_retries),
            ]
            # A usable part has one judgment per question, in question order;
            # a failed part has none. Merging by position keeps repeated
            # question texts apart.
            usable = [
                part["judgments"]
                for part in parts
                if len(part["judgments"]) == len(questions)
            ]
            merged_judgments = [
                {
                    "question": question,
                    "relevant": any(bool(js[index].get("relevant")) for js in usable),
                }
                for index, question in enumerate(questions)
            ]
            relevant_total = sum(1 for j in merged_judgments if j["relevant"])
            return {
                "relevance": relevant_total / len(merged_judgments),
                "judgments": merged_judgments,
                "retry_count": sum(part["retry_count"] for part in parts),
            }

        item_lines = "\n".join(
            f"{index}. {item}" for index, item in enumerate(atom_info_list, start=1)
        )
        # Built outside the prompt f-string: a backslash or a reused quote
        # inside an f-string replacement field is a SyntaxError before 3.12.
        question_lines = "\n".join(
            f'### Q{index}:\n"{question}"\n'
            for index, question in enumerate(questions, start=1)
        )
        user_prompt = (
            "Atomic information items:\n"
            f"{item_lines}\n\n"
            "Questions:\n"
            f"{question_lines}"
        )
        messages = [
            {"role": "system", "content": self._RELEVANCE_SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]

        verdicts, attempts = self._ask_judge(
            messages, len(questions), max_retries, context=questions
        )
        if verdicts is None:
            return {"relevance": 0.0, "judgments": [], "retry_count": attempts}

        judgments = [
            {"question": question, "relevant": relevant}
            for question, relevant in zip(questions, verdicts)
        ]
        return {
            "relevance": sum(verdicts) / len(verdicts),
            "judgments": judgments,
            "retry_count": attempts,
        }

