# -*- coding: utf-8 -*-
"""
Defines prompts for the evolution of MOOSF modules (HICA, GHOP, EOSS) within the LOS framework.

Scheme C: the combined prompt instructs the LLM to generate code for HICA and GHOP
simultaneously, outputting the result as a single string in which each part is wrapped
in specific delimiters. EOSS is held fixed in the combined prompt.

Scheme C, Option 2: separate methods generate prompts for mutating each module
(HICA, GHOP, EOSS) individually.
"""
import torch
import numpy as np
import os
# No longer need import json

# Define delimiters clearly (These are now accessed via the prompts object)
HICA_START_DELIMITER = "===HICA_CODE_START==="
HICA_END_DELIMITER = "===HICA_CODE_END==="
GHOP_START_DELIMITER = "===GHOP_CODE_START==="
GHOP_END_DELIMITER = "===GHOP_CODE_END==="
# --- REMOVED EOSS Delimiters ---


class PromptsCombined:
    # --- Add the delimiters as class attributes so they can be accessed ---
    HICA_START_DELIMITER = HICA_START_DELIMITER
    HICA_END_DELIMITER = HICA_END_DELIMITER
    GHOP_START_DELIMITER = GHOP_START_DELIMITER
    GHOP_END_DELIMITER = GHOP_END_DELIMITER
    # --- REMOVED EOSS delimiter attributes ---
    # --------------------------------------------------------------------

    def __init__(self, fixed_modules_code: dict | None = None):
        """Initialize prompt components and store fixed code."""
        self.fixed_modules_code = fixed_modules_code if fixed_modules_code else {}

    # --- Keep dummy methods to satisfy eoh internal calls ---
    def get_func_name(self):
        return "combined_moosf_evolution"

    def get_class_name(self):
        return "CombinedMOOSF"

    def get_func_inputs(self):
        return ["combined_inputs"]

    def get_func_outputs(self):
        return ["combined_outputs"]

    def get_inout_inf(self):
        return "Combined evolution: Input/output info spans multiple components."
    # ------------------------------------------------------

    def get_task(self):
        """Returns a string describing the core task and goal in the LOS context."""
        return (
            "Simultaneously evolve HICA (`Pla` class), GHOP (`_project_conflicting` method body), and EOSS (`Weight_acc` class) "
            "components originally from the MOOSF framework, now integrated within the **LOS framework**. "
            "The goal is to maximize a weighted accuracy metric (emphasizing Few-shot) on CIFAR-100 (imbalance 100), "
            "evaluated using the **LOS `main_stage2.py` script** which performs multi-head fine-tuning."
        )

    def get_component_details(self):
        # --- REVISED Component Details based on LOS/* code analysis ---
        return (
            "**Component Details (Targeting LOS Framework Implementation):**\\n"
            # --- HICA (Pla in MBACK.py) Description ---
            f"1. **HICA:** Provide full Python code string for `class Pla: ...` (including `import torch`, `import numpy as np`). Target: `LOS/MultiBackward/MBACK.py`.\\n"
            f"   **REQUIREMENTS (Based on `MBACK.py`):**\\n"
            f"   - `__init__(self, num_tasks)`: Initializes `self.beta` as a NumPy array of ones with size `num_tasks`.\n"
            f"   - `update(self, acc_s)`: Takes `acc_s` (list of **PyTorch Tensors**, where each tensor is an accuracy vector). \n"
            f"     Ensure all elements in `acc_s` are indeed PyTorch Tensors before use. Add a small epsilon (e.g., 1e-8) to denominators for numerical stability. \n"
            f"     The core calculation for similarity between `acc_s[i]` (a strategy's accuracy vector) and `acc_s[-1]` (reference accuracy vector) is typically `similarity_tensor = (acc_s[i] * acc_s[-1]).sum() / (acc_s[-1].norm(2)**2 + 1e-8)`. \n"
            f"     This `similarity_tensor` **MUST** be a single-element PyTorch Tensor. \n"
            f"     Update the NumPy array `self.beta` IN PLACE: `self.beta[i] = similarity_tensor.item()`. \n"
            f"     This `.item()` call is only valid if `similarity_tensor` is a single-element PyTorch Tensor. If `similarity_tensor` somehow results in a Python scalar directly, do not call `.item()`. Ensure the calculation consistently produces a single-element tensor. \n"
            f"     Handle potential type issues between tensor operations and NumPy array assignment carefully.\n"
            f"   - `pla(self, losses)`: Takes `losses` (list[tensor]). Modifies the list IN PLACE: `losses[i] = self.beta[i]*losses[i]`. Note: This relies on PyTorch handling multiplication between a NumPy scalar (`self.beta[i]`) and a PyTorch tensor (`losses[i]`). Return the modified `losses` list. Include asserts for beta values (e.g., `0 < abs(beta) < 1e8`).\n"
            f"   Place the **complete** `class Pla: ...` code between `{self.HICA_START_DELIMITER}` and `{self.HICA_END_DELIMITER}`.\\n"
            # --- End HICA Description ---

            # --- GHOP (_project_conflicting in pcgrad.py) Description ---
            f"2. **GHOP:** Provide Python code string for the **method body only** of `_project_conflicting(self, grads, has_grads, shapes=None)`. Include necessary imports (`import torch`, `import copy`, `import random`). Target: `LOS/MultiBackward/GradFun/pcgrad.py`.\\n"
            f"   **INPUTS (Based on `pcgrad.py`):** `grads` (list[flattened 1D tensor]), `has_grads` (list[flattened 1D boolean tensor]).\\n"
            f"   **IMPLEMENTATION (Based on `pcgrad.py`):**\\n"
            f"   a. Calculate shared mask: `shared = torch.stack(has_grads).prod(0).bool()`.\\n"
            f"   b. Copy gradients: `pc_grad = copy.deepcopy(grads)`.\\n"
            f"   c. Iterate `g_i` in `pc_grad`.\\n"
            f"   d. Inside, `random.shuffle(grads)` (original grads).\\n"
            f"   e. Inside, iterate `g_j` in shuffled `grads`.\\n"
            f"   f. Check conflict: `g_i_g_j = torch.dot(g_i, g_j)`. If `g_i_g_j < 0`, project `g_i`: `g_i -= g_i_g_j * g_j / (g_j.norm()**2)`. **Note: No epsilon added in the denominator in the reference code.**\\n"
            f"   g. Initialize `merged_grad = torch.zeros_like(grads[0]).to(grads[0].device)`.\\n"
            f"   h. Merge gradients based on `self._reduction`:\\n"
            f"      - If `'mean'`: `merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).mean(dim=0)`.\\n"
            f"      - Elif `'sum'`: `merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).sum(dim=0)`.\\n"
            f"      - Else: `exit('invalid reduction method')`.\\n"
            f"   i. Merge non-shared parts: `merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)`.\\n"
            f"   j. Final line **MUST** be `return merged_grad`. The `merged_grad` returned should be a **single, flattened PyTorch tensor**.\\n"
            f"   Place the **complete method body** between `{self.GHOP_START_DELIMITER}` and `{self.GHOP_END_DELIMITER}`.\\n"
            # --- End GHOP Description ---

            # --- EOSS Section REMOVED as it's now fixed ---
            # f"3. **EOSS:** ... "
            # --- End EOSS Description ---
        )

    def get_output_format_instructions(self):
        """Specifies the required delimited output format for HICA and GHOP only."""
        return (
            f"**CRITICAL OUTPUT FORMAT REQUIREMENT:**\\n"
            f"Your response **MUST** be a single block of text containing the code parts requested, potentially enclosed in delimiters.\\n"
            f"1. If HICA is evolved: Its code (full class) must be between `{self.HICA_START_DELIMITER}` and `{self.HICA_END_DELIMITER}`.\\n"
            f"2. If GHOP is evolved: Its code (method body) must be between `{self.GHOP_START_DELIMITER}` and `{self.GHOP_END_DELIMITER}`.\\n"
            # --- REMOVED EOSS Instruction ---
            # f"3. If EOSS ... "
            f"Ensure any required delimiters appear exactly as shown.\\n"
            f"Include necessary Python imports *only within the first code block* provided (HICA or GHOP).\\n"
            # --- REMOVED EOSS specific exclusions --- #
            f"**DO NOT** include any other text, explanations, or markdown formatting outside the requested code parts (and HICA/GHOP delimiters if used)."
        )

    def get_other_inf(self):
        """Provides additional context or constraints."""
        # Added general robustness constraints
        return (
            "- Evaluation uses CIFAR-100 (imbalance 100) within the **LOS framework (`main_stage2.py`)**.\n"
            "- Evaluation involves fine-tuning multiple classifier heads for a specified number of epochs (e.g., 50).\n"
            "- Fitness is a squared weighted accuracy (higher is better), emphasizing Few-shot performance, derived from validation results.\n"
            "- Ensure compatibility with PyTorch (e.g., 1.13+) and NumPy.\n"
            "- **General Code Quality:** Generated code MUST be robust. Anticipate potential issues like division by zero (e.g., calculating accuracy for classes with no samples), invalid array indexing (check shapes before indexing), and type mismatches between NumPy and PyTorch. Use safe-guards like adding small epsilons (e.g., 1e-8) to denominators.\n"
            "- **Valid Syntax:** Code MUST be syntactically correct Python. Use `elif` for else-if conditions."
        )

    def get_prompt(self, parent_dict=None):
        """
        Generates the complete prompt for the LLM.
        Optionally includes parent code if parent_dict is provided.
        """
        parent_code_section = ""
        if parent_dict and isinstance(parent_dict, dict):
            hica_parent = parent_dict.get(
                "hica", "# No parent HICA code provided")
            ghop_parent = parent_dict.get(
                "ghop", "# No parent GHOP code provided")
            eoss_parent = parent_dict.get(
                "eoss", "# No parent EOSS code provided")
            parent_code_section = f"""**You should modify the following parent code:**

{self.HICA_START_DELIMITER}
{hica_parent}
{self.HICA_END_DELIMITER}

{self.GHOP_START_DELIMITER}
{ghop_parent}
{self.GHOP_END_DELIMITER}

{self.EOSS_CAT_TARGETS_START_DELIMITER}
{eoss_parent}
{self.EOSS_CAT_TARGETS_END_DELIMITER}

"""

        # Emphasize delimited output
        prompt = f"""**Your goal is to output code for three components (HICA, GHOP, EOSS) enclosed in specific delimiters, for use within the LOS framework.**

{self.get_task()}

{self.get_component_details()}

{self.get_other_inf()}

{parent_code_section}
{self.get_output_format_instructions()}

**Reminder: Generate ONLY the delimited code blocks as the response.**
"""
        return prompt

    # --- NEW: Individual Prompt Generation Methods ---

    def get_hica_mutation_prompt(self, parent_hica_code=None):
        """Generates prompt specifically for HICA (Pla class) mutation.
           Inspired by the successful individual HICA evolution prompts.
        """
        # Borrowing structure/wording from individual hica prompts
        # Re-use general task description? Or make HICA specific?
        task_description = self.get_task()
        # Let's keep it specific for now.
        task_description = ("The task is to implement the Hierarchical Impact Calibration Adjustment (HICA) logic, "
                            "similar to the `Pla` class found in the MBACK module of the MOOSF framework. "
                            "This class dynamically weights losses based on accuracy similarity.")

        class_name = "Pla"  # Explicitly state the required class name

        input_description = (
            f"The `{class_name}` class needs an `__init__` method and two main methods: `update` and `pla`.\n" +
            f"- `__init__(self, num_tasks)`: Initializes `self.beta` as a numpy array of ones with size `num_tasks`.\n" +
            f"- `update(self, acc_s)`: Takes `acc_s` (list of **PyTorch tensors**: [strategy1_acc_vector, ..., strategyN_acc_vector, reference_acc_vector]). " +
            f"**Goal:** Update the `self.beta` weights based on the similarity between each strategy's accuracy vector (`acc_s[i]`) and the reference accuracy vector (`acc_s[-1]`). Higher similarity should result in a higher beta value for that strategy. You might consider using dot product, cosine similarity, or other relevant metrics. Ensure numerical stability (e.g., add epsilon like 1e-8 if dividing by norms). The `self.beta` array should be updated IN PLACE. **Handle potential type mismatches if mixing NumPy and PyTorch operations (e.g., use .item() if assigning a tensor scalar to a numpy element).**\n" +
            f"- `pla(self, losses)`: Takes `losses` (list of PyTorch tensors). Returns a new list `weighted_losses` where `weighted_losses[i] = self.beta[i] * losses[i]`. " +
            f"**CRITICAL: Ensure the multiplication handles potential type differences between `self.beta[i]` (numpy) and `losses[i]` (tensor). Convert `self.beta[i]` to a tensor matching `losses[i].dtype` and `losses[i].device` if necessary BEFORE multiplying.**"
        )

        output_description = "The `pla` method should return a list of weighted PyTorch loss tensors (`weighted_losses`)."

        constraints = (
            f"Implement the entire `{class_name}` class in Python using NumPy and PyTorch. "
            f"Include necessary imports (`import torch`, `import numpy as np`) at the top of the class code. "
            f"The class should store `beta` as a NumPy array internally. "
            f"Ensure correct type handling between NumPy arrays and PyTorch tensors as specified, especially when updating `self.beta` and calculating `weighted_losses`. "
            f"Focus on achieving the goal of the `update` method effectively and correctly implementing the `pla` method. "
        )

        parent_section = ""
        if parent_hica_code:
            parent_section = f"\n**Modify the following parent `{class_name}` class code:**\n```python\n{parent_hica_code}\n```\n"
        else:
            parent_section = f"\n**Generate the `{class_name}` class from scratch.**\n"

        # Construct the final prompt, mirroring the individual structure
        prompt = f"""{task_description}

Required Class Name: {class_name}

Method Inputs & Implementation Details:
{input_description}

Method Outputs:
{output_description}

Other Requirements & Constraints:
{constraints}
{parent_section}
**CRITICAL: Provide ONLY the complete Python code for the '{class_name}' class, including the `class {class_name}:` line and all specified methods. Ensure correct indentation and type handling. Do NOT include any example usage or explanations outside the code.**
"""
        return prompt

    def get_ghop_mutation_prompt(self, parent_ghop_code=None):
        """Generates prompt specifically for GHOP (_project_conflicting method body) mutation."""
        task_context = "Your task is to provide the **Python code for the method body** of the `_project_conflicting` method (part of the `PCGrad` class), for use within the **LOS framework** and evaluated via `main_stage2.py`."
        signature = "`_project_conflicting(self, grads, has_grads, shapes=None)`"
        inputs_desc = "**INPUTS:** `grads` (list of flattened 1D tensors), `has_grads` (a **LIST** of **boolean PyTorch TENSORS**, one tensor per parameter group, indicating parameter sharing)."
        goal_description = (
            "**Goal:** The method should identify pairs of conflicting gradients in the `grads` list (e.g., based on negative dot product). "
            "It should then modify these conflicting gradients to reduce or resolve the conflict. A common technique is projecting one gradient onto the normal plane of the other (`g_i -= dot(g_i, g_j) * g_j / (norm(g_j)**2 + 1e-8)`), but other approaches might be considered. "
            "After resolving conflicts, the potentially modified gradients (let's call them `processed_grads`) need to be merged into a single `merged_grad` tensor. "
            "**Gradient Merging Logic:**"
            "  a. Initialize `merged_grad = torch.zeros_like(processed_grads[0])`."
            "  b. Stack the boolean masks: `combined_has_grads = torch.stack(has_grads)`."
            "  c. **If `self._reduction == 'mean'`:**"
            "     i. Calculate `sum_val = combined_has_grads.sum(dim=0).float()`."
            "     ii. **Inside a loop** iterating through `processed_grads` (index `i`): Add the contribution of each gradient, applying the mask: `merged_grad += processed_grads[i] * combined_has_grads[i].bool()` (Multiplication handles masking)."
            "     iii. **AFTER the loop**, perform the normalization: `merged_grad /= (sum_val + 1e-8)`. **DO NOT perform division inside the loop.**"
            "  d. **Else if `self._reduction == 'sum'`:**"
            "     i. **Inside a loop** iterating through `processed_grads` (index `i`): Add the contribution of each gradient, applying the mask: `merged_grad += processed_grads[i] * combined_has_grads[i].bool()`."
            "  e. Handle other potential reduction types or raise an error if needed."
            # Existing CRITICAL sections remain important:
            "**CRITICAL `has_grads` HANDLING:** `has_grads` is a **LIST** of boolean tensors. You **CANNOT** use `if has_grads[i]:` or call methods like `.sum()` directly on the list itself. Operations like summing across parameter groups require first combining the tensors in the list into a single tensor (e.g., `combined_has_grads = torch.stack(has_grads)`), and then operating on `combined_has_grads`. The tensors *inside* the list (`has_grads[i]`) are **BOOLEAN**; ensure you use appropriate operations. **DO NOT use bitwise operators like `&` or `|` directly between tensors from this list (e.g., `has_grads[i] & has_grads[j]`).** If you need to check if both groups `i` and `j` have *any* gradients, use `if has_grads[i].any() and has_grads[j].any():`. "
            "**VERY IMPORTANT - BOOLEAN INDEXING:** When using `has_grads` or `combined_has_grads` (or any tensor derived from them) to **index** another tensor (e.g., `some_tensor[mask]`), the `mask` tensor **MUST** have `dtype=torch.bool`. If you are unsure about the dtype of your mask tensor, explicitly convert it using `.bool()` **before** using it for indexing (e.g., `some_tensor[mask.bool()]`). This is crucial during the gradient merging step if you select elements based on `combined_has_grads`."
            "**CRITICAL DIVISION HANDLING:** When merging gradients, if you need to normalize by the number of gradients contributing (e.g., dividing by a sum derived from `has_grads`, let's call it `sum_val`), be aware that the sum `sum_val` **can be zero** for some elements. You **MUST** handle this case to avoid division by zero. Either add a small epsilon to the denominator (e.g., `sum_val.float() + 1e-8`) or explicitly check if the sum is zero before performing the division. **Crucially, the divisor (e.g., `sum_val + 1e-8`) MUST have the same shape as `merged_grad` (a flat tensor) for element-wise division. DO NOT reshape the divisor using `.view(-1, 1)` or similar operations that change its dimensionality before the division.** **If using `torch.where`, ensure the condition tensor is boolean (`dtype=torch.bool`).**"
        )
        output_requirement = "The final line **MUST** be `return merged_grad`. The `merged_grad` returned should be a **single, flattened PyTorch tensor**. This line is **ESSENTIAL** and cannot be omitted."
        # --- ADDED CONSTRAINT --- #
        requirements = ("Include necessary imports (e.g., `import torch`, `import copy`, maybe `import random`) at the start of the method body code block. "
                        "**CRITICAL TYPE CONSTRAINT:** Perform ALL calculations using **PyTorch tensor operations ONLY**. If creating new tensors (e.g., from constants), ensure they use the correct PyTorch dtype (e.g., `torch.float32` or matching the input `grads` dtype) and are on the correct device (e.g., `grads[0].device`). **DO NOT use NumPy dtypes (like `np.float64`) in `torch.tensor()` or `.to()` calls. Use `torch.float32`, `torch.float64`, etc., instead.**"
                        "**CRITICAL RETURN CONSTRAINT:** The method body **MUST** calculate and assign a valid flattened PyTorch tensor to the variable `merged_grad`. **Before the final `return merged_grad` line, you SHOULD add an assertion like `assert isinstance(merged_grad, torch.Tensor) and merged_grad is not None, \'merged_grad must be a valid Tensor before returning\'`** to guarantee a valid tensor is returned. **DO NOT RETURN `None`**. Handle potential edge cases (like empty `grads` list) gracefully, ensuring `merged_grad` is still initialized (e.g., to zeros) and returned.")
        # --- END ADDED CONSTRAINT --- #
        parent_section = ""
        if parent_ghop_code:
            parent_section = f"\n**Modify the following parent method body code:**\n```python\n{parent_ghop_code}\n```\n"
        else:
            parent_section = f"\n**Generate the method body code from scratch.**\n"

        prompt = f"""{task_context}
Method Signature: {signature}
{inputs_desc}

{goal_description}

{output_requirement}

{requirements}
{self.get_other_inf()} # Keep general info like evaluation context
{parent_section}
**CRITICAL: Output ONLY the Python code for the method body (lines inside the function, correctly indented), including necessary imports at the beginning of the block. Do NOT include the method definition line `def _project_conflicting(...)`. Do not add explanations or markdown.**
"""
        return prompt

    def get_eoss_mutation_prompt(self, parent_eoss_code=None):
        """Generates prompt specifically for EOSS (Weight_acc class) mutation."""
        # Define the class name here
        class_name = "Weight_acc"

        task_description = ("The task is to implement the Evolving Optimal Strategy Selection (EOSS) logic, "
                            f"similar to the `{class_name}` class, for use within the **LOS framework** and evaluated via `main_stage2.py`."
                            "**Note:** Evaluation in `main_stage2.py` primarily tests validation performance (via `valid_eoss`, likely using `cat_targets`).")

        input_description = (
            f"The `{class_name}` class needs at least these methods:\\n" +
            # Use torch tensor for potential GPU use
            f"- `__init__(self, num_class, tasks)`: Initialize attributes like `self.tasks` (list of task names), `self.num_class`, `self.max_weight_task` (list to store the best task name per class), and `self.weigh_save_list` (dictionary, potentially including a PyTorch tensor for conflict logging like `torch.zeros([400, 2])`).\\n" +
            f"- `update(self, predictions_dict, targets_np)`:\\n"
            # --- MODIFIED Update Logic ---
            f"     **INPUTS:** `predictions_dict` (dict[str, np.ndarray]), where values can be **EITHER LOGITS** (`[N, C]`) **OR PREDICTED INDICES** (`[N,]`). `targets_np` (np.ndarray `[N,]`).\\n"
            f"     **GOAL:** Update `self.max_weight_task` based on per-class accuracy.\\n"
            f"     **IMPLEMENTATION STEPS:**\\n"
            f"     a. **Get Predicted Indices:** Create `predicted_indices = {{}}`. Loop `task_name, preds_or_logits in predictions_dict.items()`:\\n"
            f"        Check `preds_or_logits.ndim`. \\n"
            f"        If `ndim == 2` (Logits): `indices = np.argmax(preds_or_logits, axis=1)`\\n"
            f"        If `ndim == 1` (Indices): `indices = preds_or_logits` \\n"
            f"        Else (Error): Handle appropriately (e.g., raise error or log warning). \\n"
            f"        Store `predicted_indices[task_name] = indices` (Shape MUST be `[N,]`).\\n"
            f"     b. **Calculate Per-Class Accuracies:** Initialize `task_class_accuracies = np.zeros((len(self.tasks), self.num_class))`."
            f"     c. Loop through tasks (index `task_idx`, name `task_name`):\\n"
            f"        `task_preds_indices = predicted_indices[task_name]` # Get predicted indices (Shape [N,])\\n"
            f"        Loop through classes `c` from 0 to `self.num_class - 1`:\\n"
            f"           `class_mask = (targets_np == c)` # Boolean mask (Shape [N,])\\n"
            f"           `total_samples_in_class = np.sum(class_mask)`\\n"
            f"           if total_samples_in_class > 0:\\n"
            f"               # **CRITICAL BOOLEAN INDEXING:** Use the **full** index array and mask.\\n"
            # Shape check remains
            f"               if class_mask.shape[0] != task_preds_indices.shape[0]:\\n"
            f"                   print(f'Error: Shape mismatch update task {{task_name}} class {{c}}. Mask: {{class_mask.shape}}, Preds: {{task_preds_indices.shape}}. Acc=0.')\\n"
            f"                   `class_accuracy = 0.0`\\n"
            f"               else:\\n"
            f"                   `preds_for_this_class = task_preds_indices[class_mask]` # Index the [N,] array with [N,] mask\\n"
            f"                   `correct_preds_for_this_class = np.sum(preds_for_this_class == c)`\\n"
            f"                   `class_accuracy = correct_preds_for_this_class / total_samples_in_class`\\n"
            f"           else:\\n"
            f"               `class_accuracy = 0.0`\\n"
            f"           `task_class_accuracies[task_idx, c] = class_accuracy`\\n"
            f"     d. **Update `max_weight_task`:** (Same as before) Loop `c`, find `best_task_idx = np.argmax(task_class_accuracies[:, c])`, assign `self.max_weight_task[c] = self.tasks[best_task_idx]`. **Ensure `self.max_weight_task` has no `None` values.**\\n"
            # --- END MODIFIED Update Logic ---
            f"- `cat_out(self, logits_dict)`:\n"
            f"     **INPUTS:** `logits_dict` (dict mapping task names to **logit TENSORS** [batch_size, num_class]).\n"
            f"     **GOAL:** Combine the logits based on internal weights (e.g., `self.weigh`).\n"
            f"     **IMPLEMENTATION EXAMPLE:** `return sum(self.weigh[t].to(logits_dict[t].device) * torch.softmax(logits_dict[t], dim=1) for t in self.tasks)` (Ensure weights are moved to the correct device before multiplication). Initialize `self.weigh` in `__init__` if using this approach.\n"
            f"     **OUTPUT:** A single tensor of combined probabilities [batch_size, num_class].\n"
            f"- `cat_targets(self, logits, targets, epoch)`: **Input `logits` IS a dictionary mapping task names to logit TENSORS. Input `targets` is the tensor of TRUE CLASS INDICES.**\\n"
            f"   1. **Get Predictions:** First, calculate the predicted class index for each task for all samples: `predictions = {{task_name: torch.argmax(task_tensor, dim=1) for task_name, task_tensor in logits.items()}}`.\\n"
            f"   2. **Determine Final Class Index per Sample:** Iterate through each sample `i` from 0 to batch_size-1.\\n"
            f"      a. Check for consistency: Do all tasks in `predictions` agree on the predicted class for sample `i`? \\n"
            f"      b. **If Consistent:** The final determined class index for sample `i`, let's call it `final_class_index`, is the agreed-upon prediction (e.g., `predictions[self.tasks[0]][i]`).\\n"
            f"      c. **If Conflicting:** Use the **true class** `true_class = targets[i].item()` **ONLY** to look up the best performing task: `best_task = self.max_weight_task[true_class]` (handle potential None `best_task` by using a fallback like `self.tasks[0]`). The final determined class index for sample `i` (`final_class_index`) is the **prediction made by that `best_task`**, i.e., `final_class_index = predictions[best_task][i]`.\n"
            f"      d. **ABSOLUTE RULE:** The value assigned to `final_class_index` in steps 2b and 2c **MUST** come from the `predictions` dictionary. **NEVER** assign `final_class_index = targets[i]` or use `targets[i]` in any way to construct the final index value itself.\n"
            f"      e. **CRITICAL ASSIGNMENT:** Ensure that inside the loop (step 2), `final_class_index` is assigned a valid integer value (predicted class index) in **BOTH** the consistent case (step 2b) and the conflicting case (step 2c). Failure to assign in all paths will cause an `UnboundLocalError`. Initialize `final_class_index = -1` (or another indicator) at the start of the loop if needed for debugging, but ensure it's overwritten.\n"
            f"   3. **Construct Output Tensor:** Create the final one-hot tensor `output_tensor` of shape `[batch_size, self.num_class]`. For each sample `i`, set the element corresponding to its `final_class_index` to 1.0 (e.g., `output_tensor[i, final_class_index] = 1.0`).\n"
            f"   4. **Log Conflicts:** Correctly log the number of consistent and conflicting samples into `self.weigh_save_list['conflict'][epoch, 0/1]` using `.item()`.\n"
            f"   5. **Return:** Return the `output_tensor`. Ensure it has `dtype=float` and is on the **same device** as the input `targets` tensor (use `.to(targets.device)` if needed).\\n" +
            f"- `cuda(self)`: Method to move internal tensors (if any) to CUDA device. **IMPORTANT: Before calling `.cuda()` on any attribute (e.g., values in `self.weigh_save_list`), check if it's a PyTorch tensor using `isinstance(value, torch.Tensor)`. Skip non-tensors.**\\n"
        )  # Closing parenthesis for input_description

        output_description = f"The `cat_targets` method should return a **single PyTorch tensor** representing the final fused predictions (one-hot encoded), **shape `[batch_size, self.num_class]`**, **dtype `float`**, and on the **same device as `targets`**."

        constraints = (
            f"Implement the entire `{class_name}` class in Python using NumPy and PyTorch. "
            f"Include necessary imports (`import torch`, `import numpy as np`) at the top of the class code block. "
            f"Ensure the class correctly maintains state (like `self.max_weight_task`) between calls to `update` and `cat_targets`. "
            f"Focus on robustly implementing the logic for `update` (determining best task, **no Nones in result**, **handling 1D/2D logits for argmax**) and `cat_targets` (using best task's **prediction** for conflicts, **handling potential None from lookup**, logging, **correct output shape/dtype/device**)."
            f"**CRITICAL SIGNATURE CONSTRAINT:** The `cat_targets` method definition line **MUST** exactly match `def cat_targets(self, logits, targets, epoch):`. Do not omit the `epoch` parameter."
            f"**CRITICAL CONSTRAINT 1:** The `cat_targets` method **MUST NOT** use the values from the input `targets` tensor to directly determine the class index for the final output one-hot tensor. The determination of the final class index for each sample must solely depend on the `predictions` derived from `logits` and the task selection logic using `self.max_weight_task`. Using `targets` values directly in the output construction will invalidate the evaluation."
            f"**CRITICAL CONSTRAINT 2:** Ensure the returned one-hot tensor correctly reflects the determined `final_class_index` for each sample based ONLY on the model's predictions and the EOSS selection logic."
            f"**CRITICAL PYTHON INDENTATION:** All code within the class methods (`__init__`, `update`, `cat_targets`, `cuda`) **MUST** be correctly indented. Incorrect indentation will lead to `SyntaxError`."
            f"**CRITICAL: The generated code MUST be pure, valid Python syntax. DO NOT include any markdown formatting (like ```), comments explaining the code, or any text other than the Python code itself within the code block.**"
        )  # Closing parenthesis for constraints

        parent_section = ""
        if parent_eoss_code:
            parent_section = f"\n**Modify the following parent `{class_name}` class code:**\n```python\n{parent_eoss_code}\n```\n"
        else:
            # Use the defined class_name here
            parent_section = f"\n**Generate the `{class_name}` class from scratch.**\n"

        # CRITICAL: Ensure the final output is just the class code block
        prompt = f"""**Task:** Evolve the EOSS logic by implementing the **complete** Python `{class_name}` class.

{task_description}

**Required Class Structure and Methods:**
{input_description}

**Expected Output from cat_targets:**
{output_description}

**Constraints & Considerations:**
{constraints}

{parent_section}

**Output Requirement:** Provide **only** the complete Python code for the `{class_name}` class, starting with `import ...` and `class {class_name}:`. Do **not** include any other text, explanations, or markdown formatting before or after the code block.
"""
        return prompt


# Example usage (optional)
if __name__ == '__main__':
    # Example: Create an instance (assuming original code snippets are defined)
    # Define some dummy original code for testing
    original_hica_code_str = "class Pla:\n    pass"
    original_ghop_body_str = "pass"
    # --- Need the actual original Weight_acc class code ---
    # Read from file or define here for standalone testing
    try:
        eoss_orig_file = os.path.abspath(os.path.join(os.path.dirname(
            __file__), '..', '..', '..', 'torch-MOOSF', 'MultiBackward', 'ACCFun', 'Mul_same_task.py'))
        with open(eoss_orig_file, 'r', encoding='utf-8') as f:
            eoss_content = f.read()
        # Extract the Weight_acc class (simple extraction based on 'class Weight_acc:' start)
        class_start = eoss_content.find("class Weight_acc:")
        # Find the end (assuming next class/def or end of file at base indent) - This is approximate!
        class_end = len(eoss_content)  # Default to end
        next_block_indent = -1
        if class_start != -1:
            lines = eoss_content[class_start:].splitlines()
            base_indent = len(lines[0]) - len(lines[0].lstrip())
            for i in range(1, len(lines)):
                line = lines[i]
                if line.strip() and (len(line) - len(line.lstrip())) <= base_indent:
                    next_block_indent = i
                    break
            if next_block_indent != -1:
                original_eoss_class_str = "\n".join(lines[:next_block_indent])
            else:
                original_eoss_class_str = "\n".join(lines)

        else:
            original_eoss_class_str = "class Weight_acc:\n    pass # Original not found"

    except Exception as e:
        print(f"Warning: Could not load original EOSS code for example: {e}")
        original_eoss_class_str = "class Weight_acc:\n    pass # Error loading"
    # -----------------------------------------------------

    prompts = PromptsCombined()
    # Test combined prompt
    # combined_prompt = prompts.get_prompt()
    # print("--- Combined Prompt ---")
    # print(combined_prompt)
    # print("\n" + "-"*80 + "\n")

    # Test individual HICA prompt
    hica_prompt = prompts.get_hica_mutation_prompt(original_hica_code_str)
    print("--- HICA Mutation Prompt ---")
    print(hica_prompt)
    print("\n" + "-"*80 + "\n")

    # Test individual GHOP prompt
    ghop_prompt = prompts.get_ghop_mutation_prompt(original_ghop_body_str)
    print("--- GHOP Mutation Prompt ---")
    print(ghop_prompt)
    print("\n" + "-"*80 + "\n")

    # Test individual EOSS prompt (using extracted code)
    eoss_prompt = prompts.get_eoss_mutation_prompt(original_eoss_class_str)
    print("--- EOSS Mutation Prompt ---")
    print(eoss_prompt)
    print("\n" + "-"*80 + "\n")

    # Test combined prompt with parent code
    parent_dict = {
        "hica": original_hica_code_str,
        "ghop": original_ghop_body_str,
        "eoss": original_eoss_class_str
    }
    combined_parent_prompt = prompts.get_prompt(parent_dict)
    print("--- Combined Prompt with Parent ---")
    print(combined_parent_prompt)
    print("\n" + "-"*80 + "\n")
