from core.infra.llm import LLMApi
from core.model.prompt import Prompt
from core.alg.evaluator import Evaluator
import random


class Mutator:
    """Produce new candidate prompts from an existing population via LLM-driven mutation strategies.

    Every public strategy returns a fresh ``Prompt`` (scores zeroed, no bad cases) and falls back
    to an existing prompt's text when the LLM call returns an empty/None result.
    """

    def __init__(self, problem_description, population, is_f1=False):
        """
        Args:
            problem_description: Natural-language task description embedded into generation requests.
            population: Current prompt population; ``population[0]`` is treated as the best prompt
                and used as the fallback/seed by ``badcase_tracker_generation``.
            is_f1: Forwarded to ``Evaluator`` when re-checking prompts on bad cases.
        """
        self.problem_description = problem_description
        self.population = population
        self.is_f1 = is_f1
        self.system_message = '''
            # Role
            You are an extremely professional LLM prompt engineer who is proficient in using numerous prompt optimization techniques, such as Chain of Thought (COT), Tree of Thought (TOT), Reflection, and so on. 
            # Task
            Please optimize the prompt based on the provided instructions. Output only the generated prompt in the end. 
        '''
        self.MAX_ITERS = 3  # upper bound on reflection/rewrite rounds
        self.TOP_FEWSHOT_K = 6  # upper bound on few-shot exemplars in the final prompt
        self.MIN_INPUT_LEN = 6  # drop very short inputs to avoid noisy few-shot exemplars

    def badcase_tracker_generation(self, badcase_tracker, top_badcase_k=5):
        """
        Generate a new prompt by analyzing high-frequency badcases and their associated failed prompts.

        Args:
            badcase_tracker: Object whose ``get_badcases()`` returns a dict mapping input text to
                ``{'count': int, 'prompts': iterable of prompt texts}``.
            top_badcase_k: Number of most frequent badcases to show the LLM.

        Returns:
            A new ``Prompt``. Falls back to ``direct_mutation(population[0])`` when there are no
            badcases, and to ``population[0].text`` when the LLM returns nothing.
        """
        # 1. Get top-k most frequent badcases
        badcase_dict = badcase_tracker.get_badcases()
        badcase_list = sorted(
            ((input_text, info['count'], info['prompts']) for input_text, info in badcase_dict.items()),
            key=lambda item: item[1],
            reverse=True,
        )[:top_badcase_k]
        if not badcase_list:
            # No valid badcases, fallback to direct mutation of the best prompt in population
            return self.direct_mutation(self.population[0])

        # 2. Build badcase info block
        badcase_block = "\n\n".join([
            f"Input: {input_text}\nErrorTimes: {count}\nFailedPrompts:\n" +
            "\n".join(f"- {prompt_text}" for prompt_text in prompts)
            for input_text, count, prompts in badcase_list
        ])

        # 3. Prepare LLM instruction
        user_content = f"""
You are an expert LLM prompt engineer. Your task is to review the most frequently misclassified inputs along with the prompts that failed to handle them, and design a new, more robust prompt for the given task. Please analyze the typical errors and shortcomings of the failed prompts, and combine advanced prompt engineering techniques (such as Chain of Thought, Reflection, preventive constraints, etc.) to maximize the prompt's reliability for similar cases.
Output ONLY the improved prompt in English, without explanations.

# Task Description:
{self.problem_description}

# High-frequency Badcase Analysis (error counts and failed prompts):
{badcase_block}
"""

        messages = [
            {"role": "system", "content": self.system_message},
            {"role": "user", "content": user_content}
        ]
        llm = LLMApi("azure_gpt")
        new_prompt_text = llm.generate(messages)
        if not new_prompt_text:  # Fallback
            new_prompt_text = self.population[0].text

        return Prompt(new_prompt_text, 0, 0, [], [], [])

    def direct_mutation(self, prompt):
        """Ask the LLM to paraphrase ``prompt.text`` without changing its meaning.

        Returns a new ``Prompt`` whose ancestor is ``prompt``; keeps the original text when the
        LLM returns an empty/None result.
        """
        llm_api = LLMApi("azure_gpt")
        user_content = f"Please rephrase the following instruction in an alternative manner while retaining its original meaning: {prompt.text}"
        # FIXME: for F1 tasks the line above used to be the following — is it worth distinguishing?
        #        | user_content = f"Rephrase this instruction without changing its meaning: {prompt.text}"
        messages = [
            {"role": "system", "content": self.system_message},
            {"role": "user", "content": user_content}
        ]
        new_prompt_text = llm_api.generate(messages)
        # Treat an empty reply like a failed call, consistent with the other strategies.
        if not new_prompt_text:
            new_prompt_text = prompt.text
        return Prompt(new_prompt_text, 0, 0, [], ancestors=[prompt], bad_cases=[])

    def zero_order_generation(self, prompt):
        """Ask the LLM for a structurally new prompt inspired by every prompt in the population.

        ``prompt`` is only used as the fallback text when the LLM returns an empty/None result.
        """
        llm_api = LLMApi("azure_gpt")
        # FIXME: for F1 tasks the line above used to be the following — is it worth distinguishing?
        #        | llm_api = LLMApi("azure_gpt_thinking")
        combined_prompts = '\n'.join(p.text for p in self.population)
        user_content = f'''
            The optimization of the current effect has reached a standstill. 
            Please analyze the characteristics of all existing prompts. Generate an entirely new prompt with a more innovative structure and a richer set of techniques. Note that more prompt techniques should be employed.
            The following are all the current prompts:{combined_prompts}
        '''
        messages = [
            {"role": "system", "content": self.system_message},
            {"role": "user", "content": user_content}
        ]
        new_prompt_text = llm_api.generate(messages)
        if not new_prompt_text:
            new_prompt_text = prompt.text
        return Prompt(new_prompt_text, 0, 0, [], [], [])

    def badcase_reflection(self, prompt, *, sampling='fixed', sample_num=15, sample_ratio=0.2):
        """Refine ``prompt`` with a reflect -> rewrite -> re-evaluate loop over its bad cases.

        Steps:
          1. Sample the prompt's bad cases (see ``sampling``).
          2. For up to ``MAX_ITERS`` rounds: build a self-reflection prompt from the active cases,
             have the LLM rewrite the system prompt using (reflection + current prompt + cases),
             then re-run the rewritten prompt on the same cases. A case still failing in round
             ``it`` gains ``it`` points (1st round +1, 2nd +2, ...). Stop early once every active
             case passes; otherwise the next round only uses the still-failing cases.
          3. Append the Top-K highest-scoring sampled cases as few-shot exemplars (inserted before
             ``## Prediction`` when that marker exists).

        Args:
            prompt: Source ``Prompt``; reads ``text`` and ``bad_cases`` (a sequence of
                ``(input, expected)`` pairs).
            sampling: ``'fixed'`` (default: up to ``sample_num`` cases), ``'ratio'``
                (``sample_ratio`` of the cases, at least one), or ``'all'``. Unknown values use all.
            sample_num: Case count for ``'fixed'`` sampling.
            sample_ratio: Fraction for ``'ratio'`` sampling.

        Returns:
            A new ``Prompt`` whose ancestor is ``prompt``; delegates to ``direct_mutation`` when
            ``prompt`` has no bad cases.
        """
        if not prompt.bad_cases:
            return self.direct_mutation(prompt)

        current_system_prompt = prompt.text
        fail_scores = {}
        sampled_cases = self._sample_badcases(prompt.bad_cases, sampling, sample_num, sample_ratio)
        active_cases = sampled_cases

        for it in range(1, self.MAX_ITERS + 1):
            if not active_cases:
                break

            # 1) Build a MAPS-style self-reflection prompt from the active cases.
            reflection_prompt = self._build_reflection_meta_prompt__ALL(active_cases, current_system_prompt)

            # 2) Rewrite the system prompt from (reflection + current prompt + failing cases);
            #    on failure keep the previous prompt.
            improved_system = self._update_system_prompt(reflection_prompt, active_cases, current_system_prompt)
            if improved_system:
                current_system_prompt = improved_system

            # 3) Re-evaluate on the same cases and score the ones that still fail.
            still_failing = self._evaluate_on_cases(current_system_prompt, active_cases)
            if not still_failing:
                break  # every active case passes: converged early

            for inp, _exp in still_failing:
                fail_scores[inp] = fail_scores.get(inp, 0) + it

            # Narrow the next round to the still-failing cases (faster convergence).
            active_cases = still_failing

        # 4) Few-shots: rank the whole sample by accumulated fail score. Ranking over the narrowed
        #    set would be meaningless, since every member of it has the same score.
        few_shot_block = self._build_fewshot_block__ALL(
            all_cases=sampled_cases,
            score_map=fail_scores,
            top_k=self.TOP_FEWSHOT_K,
            min_len=self.MIN_INPUT_LEN
        )
        final_text = self._insert_fewshots(current_system_prompt, few_shot_block)

        return Prompt(final_text, 0, 0, [], ancestors=[prompt], bad_cases=[])

    @staticmethod
    def _sample_badcases(badcases, sampling='fixed', sample_num=15, sample_ratio=0.2):
        """Return the subset of ``badcases`` to train on, chosen by ``sampling`` mode.

        ``'fixed'`` draws up to ``sample_num`` cases, ``'ratio'`` draws ``sample_ratio`` of them
        (at least one, never more than available); ``'all'`` or any unknown mode returns
        ``badcases`` unchanged. ``badcases`` must be a sequence for ``random.sample``.
        """
        if not badcases:
            return badcases
        if sampling == 'ratio':
            n_sample = min(len(badcases), max(1, int(len(badcases) * sample_ratio)))
            return random.sample(badcases, n_sample)
        if sampling == 'fixed':
            return random.sample(badcases, min(sample_num, len(badcases)))
        return badcases  # 'all' and unrecognized modes use every case

    @staticmethod
    def _insert_fewshots(system_prompt, few_shot_block, insert_marker="## Prediction"):
        """Insert ``few_shot_block`` under a few-shot header into ``system_prompt``.

        The block goes right before the first ``insert_marker`` so the prediction section stays
        last; without the marker it is appended at the end. An empty block returns the prompt
        unchanged.
        """
        if not few_shot_block:
            return system_prompt
        header = "# Few-Shot Exemplars (from high-scoring failures)\n"
        if insert_marker in system_prompt:
            before, after = system_prompt.split(insert_marker, 1)
            # keep the original prediction section after the exemplars
            return before + header + few_shot_block + "\n" + insert_marker + after
        return system_prompt + "\n\n" + header + few_shot_block

    def _build_reflection_meta_prompt__ALL(self, all_cases, prompt):
        """Ask the LLM for a Self-Reflection prompt tailored to ``prompt`` and its failures.

        Args:
            all_cases: ``(input, expected)`` string pairs; pairs with an empty side are skipped.
            prompt: The system prompt text currently in use.

        Returns:
            The generated reflection prompt, or ``""`` if the LLM returned nothing.
        """
        # Render the failing cases as a readable block.
        badcase_block = "\n".join([
            f"Input: {inp.strip()}\nExpected: {exp.strip()}"
            for inp, exp in all_cases if inp and exp
        ])

        # Task description (standard self-reflection template).
        task_description = f'''
        Role
        You are an expert at adapting instructions for language models. Your task is to create a personalized Self-Reflection prompt for a model that is trying to solve a problem.

        Task Description
        You are working with the following system prompt currently in use by the model:
        ---------------------
        {prompt}
        ---------------------

        Your task is to modify the Self-Reflection template to make it as specific and helpful as possible for the given problem. Please pay special attention to the following aspects:
        • Type of problem: The Self-Reflection prompt should guide the model to think about and address the specific type of problem presented in the question.
        • Common mistakes: The Self-Reflection prompt should guide the model to identify common mistakes made when solving this type of problem.
        • Complexity of the problem: The Self-Reflection prompt should guide the model to understand the complexity of the problem and consider whether more steps are needed to solve it.

        Self-Reflection Template
        You are an expert in <PROBLEM AREA>.
        You have answered the following question incorrectly.
        Your task is to reflect on the problem, your solution, and the correct answer.
        Then, you will use this information to help you answer similar questions correctly in the future:

        Step 1: Explain why you answered the question incorrectly.
        Step 2: List keywords describing the types of your errors, ordered from general to specific.
        Step 3: Create a detailed set of instructions to help you answer this type of question correctly in the future.
        Step 4: List general advice for solving similar types of problems.

        Be concise in your response, but include all key information.

        [Example Output]
        All error examples:
        Explanation: The reasons for these errors are: <explanation of your mistakes>
        Error Keywords: <keywords related to the mistake>
        Instructions: <step-by-step instructions on how to solve such problems>
        Advice: <general advice for solving similar problems>
        Solution: <the correct reasoning and answer for the problem>

        The following are real error examples the model has made:
        ---------------------
        {badcase_block}
        ---------------------

        Final Task
        Now, adapt the above template based on the system prompt and the errors shown.
        Generate a new Self-Reflection prompt tailored for improving performance on similar problems.
        '''

        # Assemble the final request.
        # NOTE(review): badcase_block is already embedded in task_description; it is repeated here,
        # doubling its token cost — kept as-is to preserve the prompt the LLM sees.
        final_prompt = (
                task_description +
                "\nAll failing examples:\n" +
                badcase_block +
                "\n\nGenerate the adapted Self-Reflection prompt."
        )

        llm = LLMApi("azure_gpt")
        messages = [
            {"role": "system", "content": "You are a prompt optimization assistant."},
            {"role": "user", "content": final_prompt}
        ]
        reflection_prompt = llm.generate(messages)

        return reflection_prompt or ""  # fallback: empty string when generation fails

    def _update_system_prompt(self, reflection, all_cases, prompt):
        """
        Ask the LLM to rewrite the system prompt from the reflection, the current prompt and the
        failing cases.

        Args:
            reflection: Reflection text from ``_build_reflection_meta_prompt__ALL``.
            all_cases: ``(input, expected)`` pairs the current prompt fails on.
            prompt: Current system prompt text (a ``str``).

        Returns:
            The rewritten prompt text, or ``prompt`` unchanged if the LLM returned nothing.
        """
        llm = LLMApi("azure_gpt")
        # FIXME: for F1 tasks the line above used to be the following — is it worth distinguishing?
        #        | llm_api = LLMApi("azure_gpt_thinking")

        cases_block = "\n".join([f"Input: {i}\nExpected: {e}" for i, e in all_cases])

        user = (
            "Update the SYSTEM PROMPT for a math/navigation reasoning model.\n"
            "Given:\n"
            "1) Current System Prompt\n"
            "2) Reflection (diagnosis, error keywords, rewrite principles, sketch)\n"
            "3) All failing cases (no sampling)\n\n"
            "Produce a SINGLE improved SYSTEM PROMPT. "
            "Return ONLY the new prompt text (no markdown, no comments).\n\n"
            "=== Current System Prompt ===\n"
            f"{prompt}\n\n"
            "=== Reflection ===\n"
            f"{reflection}\n\n"
            "=== Failing Cases ===\n"
            f"{cases_block}\n"
        )

        messages = [
            {"role": "system", "content": self.system_message},
            {"role": "user", "content": user}
        ]

        new_prompt_text = llm.generate(messages)

        # On failure keep the old prompt. ``prompt`` is already a str; the previous
        # ``prompt.text`` raised AttributeError on exactly this path.
        return new_prompt_text or prompt

    def _evaluate_on_cases(self, system_prompt, cases):
        """
        Re-run ``system_prompt`` on the same bad cases and return the subset that still fails,
        as a list of ``(input, expected)`` pairs.

        A case passes when the parsed prediction equals ``expected`` after ``str``/strip/upper
        normalization.
        """
        llm = LLMApi()
        evaluator = Evaluator(self.is_f1)  # built once instead of once per case
        still_fail = []
        for inp, exp in cases:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": inp}
            ]
            resp = llm.generate(messages)
            pred = evaluator.parse_response(resp)
            if str(pred).strip().upper() != str(exp).strip().upper():
                still_fail.append((inp, exp))
        return still_fail

    def _build_fewshot_block__ALL(self, all_cases, score_map, top_k=6, min_len=6):
        """
        Select the Top-K cases by ``score_map`` and render them as a compact few-shot text block.

        Inputs shorter than ``min_len`` (after stripping) are dropped as noise; duplicate inputs
        keep the first ``expected`` seen. ``score_map`` keys are matched after stripping, the same
        normalization applied to the candidate inputs. Returns ``""`` when nothing qualifies.
        """
        # Normalize score keys so inputs with surrounding whitespace still get their score.
        scores = {}
        for raw_key, value in score_map.items():
            key = (raw_key or "").strip()
            scores[key] = scores.get(key, 0) + value

        # De-duplicate by input, keeping the first expected value.
        seen = {}
        for i, e in all_cases:
            i_s = (i or "").strip()
            if len(i_s) < min_len or i_s in seen:
                continue
            seen[i_s] = (e or "").strip()

        # sorted() is stable, so equal scores keep their original order.
        ranked = sorted(seen.items(), key=lambda kv: scores.get(kv[0], 0), reverse=True)[:top_k]
        if not ranked:
            return ""

        return "\n".join(
            f"[Example {idx}]\nInput: {inp}\nExpected: {exp}\n"
            for idx, (inp, exp) in enumerate(ranked, 1)
        )