# -*- coding: utf-8 -*-
import time
import numpy as np
import torch
import torch.nn as nn
import traceback
import sys
import os
import re
import shutil
import tempfile
import subprocess
import json
from datetime import datetime  # Import datetime for timestamps
import textwrap
import threading
# Assuming prompts file is in the same dir
from moosf_combined_prompts import PromptsCombined

# --- Configuration ---
# MOOSF_PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..', 'torch-MOOSF'))
# Target files
# HICA_TARGET_FILE = os.path.join(MOOSF_PROJECT_DIR, 'MultiBackward', 'MBACK.py')
# GHOP_TARGET_FILE = os.path.join(MOOSF_PROJECT_DIR, 'MultiBackward', 'GradFun', 'pcgrad.py')
# EOSS_TARGET_FILE = os.path.join(MOOSF_PROJECT_DIR, 'MultiBackward', 'ACCFun', 'Mul_same_task.py')
# MAIN_SCRIPT_PATH = os.path.join(MOOSF_PROJECT_DIR, 'main.py')
# Training Arguments - Use MOOSF flag to enable all components
# TRAIN_ARGS = [
#     '--dataset', 'cifar100',
#     '--imb_ratio', '100',
#     '--num_max', '100',
#     '--network', 'resnet32',
#     '--epochs', '50', # Current epoch setting-
#     '--ce',
#     '--bs',
#     '--MOOSF',      # Enable pla, pcg, out_cut
#     '--batch-size', '64',
# ]
FITNESS_PENALTY = 10000.0
EVALUATION_TIMEOUT = 8000
# -----------------------


# --- Helper: Function to inject a complete class definition ---
# Similar to _inject_hica, but more generic
def inject_class_definition(target_file, class_name, class_code):
    """Replace the entire definition of a class in a target file.

    The first line matching ``class <class_name>`` is located and its
    indentation is taken as the base indent.  The class is considered to end
    at the first following line of real code (not blank, not a comment) whose
    indentation is less than or equal to that base indent, or at end of file.
    Blank lines and comments at or below the base indent that sit directly in
    front of that next line are kept in the file, so the separation (and any
    leading comment) of the following definition survives the replacement.

    ``class_code`` is dedented and then re-indented relative to the base
    indent of the class being replaced.

    :param target_file: Path of the Python file to modify in place.
    :param class_name: Name of the class whose definition is replaced.
    :param class_code: Source code that replaces the class definition.
    :return: True if the file was rewritten, False on any failure (empty
        code, class not found, missing file, or any unexpected error).
    """
    try:
        # Refuse empty replacements up front so the file is never touched.
        if not class_code or not class_code.strip():
            return False

        # utf-8 keeps this consistent with the other injectors in this file.
        with open(target_file, "r", encoding="utf-8") as f:
            lines = f.readlines()

        class_def_pattern = re.compile(
            r"^(\s*)class\s+" + re.escape(class_name) + r"\b")

        # Find the start of the class definition.
        start_line = -1
        base_indent_level = -1
        for i, line in enumerate(lines):
            match = class_def_pattern.match(line)
            if match:
                start_line = i
                base_indent_level = len(match.group(1))
                break

        if start_line == -1:
            print(
                f"Error inject_class: Class '{class_name}' definition not found in {target_file}")
            return False

        # Find the end of the class: first code line at same or lower indent.
        # If none is found the class is assumed to run to the end of the file.
        end_line = len(lines)
        for i in range(start_line + 1, len(lines)):
            stripped = lines[i].lstrip()
            if not stripped or stripped.startswith("#"):
                continue  # blank lines and comments never terminate the class
            if len(lines[i]) - len(stripped) <= base_indent_level:
                end_line = i
                break

        # Hand back trailing blank lines and outer-level comments: they belong
        # to whatever follows the class, not to the class being replaced.
        while end_line > start_line + 1:
            candidate = lines[end_line - 1]
            stripped = candidate.lstrip()
            is_outer_comment = (
                stripped.startswith("#")
                and len(candidate) - len(stripped) <= base_indent_level
            )
            if stripped and not is_outer_comment:
                break
            end_line -= 1

        # Dedent the provided code, then shift it to the class's base indent.
        injected_lines = []
        for line in textwrap.dedent(class_code).splitlines():
            if line.strip():
                content = line.lstrip(" ")
                relative_indent_level = len(line) - len(content)
                injected_lines.append(
                    " " * (base_indent_level + relative_indent_level)
                    + content + "\n")
            else:
                injected_lines.append("\n")  # Preserve blank lines

        new_content = lines[:start_line] + injected_lines + lines[end_line:]

        with open(target_file, "w", encoding="utf-8") as f:
            f.writelines(new_content)
        return True

    except FileNotFoundError:
        print(f"Error inject_class: Target file not found: {target_file}")
        return False
    except Exception as e:
        print(
            f"Error inject_class: An unexpected error occurred during injection for {class_name} into {target_file}: {e}")
        traceback.print_exc()
        return False


# --- End Helper ---


class MOOSFCombinedProblem:
    def __init__(
        self,
        prompts_instance,
        output_dir,
        target_gpus: list[int] | None = None,
        project_dir: str | None = None,
        evaluation_script: str = "main_stage2.py",
        evaluation_args: list[str] | None = None,
        # --- ADDED for new directory naming ---
        population_size: int = 0,  # Default to 0 if not provided
        total_generations: int = 0,  # Default to 0 if not provided
        modules_to_evolve: list[str] | None = None
    ):
        """
        Initializes the combined MOOSF evolution problem within the LOS framework.

        :param prompts_instance: Prompts object (expected to be a PromptsCombined instance).
        :param output_dir: The main output directory for the EOH experiment; its
            basename is used as the experiment tag of the evaluation log directory.
        :param target_gpus: Optional non-empty list of integer GPU IDs to use for
            evaluation. Anything else falls back to ``[0]`` with a warning.
        :param project_dir: Path to the LOS project root directory (must exist).
        :param evaluation_script: Name of the evaluation script inside
            ``project_dir`` (e.g., 'main_stage2.py').
        :param evaluation_args: List of command-line arguments for the evaluation script.
        :param population_size: Population size, used only for log directory naming.
        :param total_generations: Total generation count, used only for log directory naming.
        :param modules_to_evolve: Names of the evolved modules, used only for log
            directory naming.
        :raises ValueError: If ``project_dir`` is missing or not a directory.
        :raises FileNotFoundError: If a target file or the evaluation script is
            missing from the project.
        """
        self.prompts = prompts_instance
        self.output_dir = output_dir

        # The experiment tag is the last path component of output_dir
        # (a trailing separator is stripped so basename is not empty).
        current_experiment_tag = os.path.basename(output_dir.rstrip(os.sep))
        eval_log_dirname = f"evaluation_logs_{current_experiment_tag}"
        self.eval_log_dir = os.path.join(self.output_dir, eval_log_dirname)

        # Create the base evaluation log directory.  exist_ok=True avoids the
        # check-then-create race: a concurrent creator must not disable logging.
        # A path that exists but is not a directory still raises and is
        # reported, instead of being silently accepted.
        log_dir_preexisting = os.path.isdir(self.eval_log_dir)
        try:
            os.makedirs(self.eval_log_dir, exist_ok=True)
        except OSError as e:
            print(
                f"ERROR: Could not create base evaluation log directory {self.eval_log_dir}: {e}"
            )
            # Later steps fall back to logging inside the temporary directory.
            self.eval_log_dir = None  # Indicate failure
        else:
            if not log_dir_preexisting:
                print(
                    f"Created base evaluation log directory: {self.eval_log_dir}")

        # Store LOS specific info
        if not project_dir or not os.path.isdir(project_dir):
            raise ValueError(
                "Valid 'project_dir' (path to LOS project) must be provided."
            )
        self.project_dir = project_dir
        self.evaluation_script = evaluation_script
        self.evaluation_args = evaluation_args if evaluation_args else []
        self.main_script_path = os.path.join(
            self.project_dir, self.evaluation_script)

        # Values used only when naming per-evaluation log directories.
        self.population_size = population_size
        self.total_generations = total_generations
        self.modules_to_evolve = modules_to_evolve if modules_to_evolve is not None else []

        # Define target file paths relative to the LOS project dir
        self.hica_target_relpath = os.path.join("MultiBackward", "MBACK.py")
        self.ghop_target_relpath = os.path.join(
            "MultiBackward", "GradFun", "pcgrad.py")
        self.eoss_target_relpath = os.path.join(
            "MultiBackward", "ACCFun", "Mul_same_task.py"
        )

        # Store or set default target GPU IDs
        if (
            target_gpus
            and isinstance(target_gpus, list)
            and len(target_gpus) > 0
            and all(isinstance(gpu_id, int) for gpu_id in target_gpus)
        ):
            self.target_gpu_ids = target_gpus
            print(
                f"LOS+MOOSF Problem configured to use specific GPUs: {self.target_gpu_ids}"
            )
        else:
            print(
                f"Warning: Invalid or empty target_gpus ({target_gpus}). Defaulting to GPU 0."
            )
            # Default to GPU 0 if list is invalid or empty
            self.target_gpu_ids = [0]
        # Store the count as well (always >= 1 at this point)
        self.num_gpus_to_use = len(self.target_gpu_ids)

        # Absolute paths of the injection targets inside the LOS project
        self.hica_target_file = os.path.join(
            self.project_dir, self.hica_target_relpath)
        self.ghop_target_file = os.path.join(
            self.project_dir, self.ghop_target_relpath)
        self.eoss_target_file = os.path.join(
            self.project_dir, self.eoss_target_relpath)

        # Fail fast if any required project path is missing.
        for path in [
            self.project_dir,
            self.hica_target_file,
            self.ghop_target_file,
            self.eoss_target_file,
            self.main_script_path,
        ]:
            if not os.path.exists(path):
                raise FileNotFoundError(f"Required LOS path not found: {path}")

        # Print statements
        print("- LOS+MOOSF Problem Initialized -")
        print(f"  Project Dir: {self.project_dir}")
        print(f"  Target HICA: {self.hica_target_relpath}")
        print(f"  Target GHOP: {self.ghop_target_relpath}")
        print(f"  Target EOSS: {self.eoss_target_relpath}")
        print(f"  Eval Script: {self.evaluation_script}")
        print(f"  Eval Args: {' '.join(self.evaluation_args)}")
        print(f"  Eval Timeout: {EVALUATION_TIMEOUT}s")
        print("  Using temporary code copies for parallel evaluation.")
        print(
            f"  Using specific GPUs {self.target_gpu_ids} assigned cyclically.")

    def _prepare_evaluation_environment(self, individual_dict: dict) -> tuple[str, str, str, list]:
        """
        Creates a temporary directory, copies the LOS project, and prepares eval args.
        Returns (temp_dir_path, temp_project_base, log_file_path, eval_args).
        """
        worker_id = individual_dict.get(
            "worker_id", "N/A")  # Get worker ID if available
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        unique_id = f"{timestamp}_W{worker_id}_G{individual_dict.get('generation', 'N/A')}"

        # --- Create unique temporary directory ---
        # Use a base temp directory for better organization if desired
        base_temp_dir = "/tmp/eoh_eval_runs"  # Example base temp dir
        os.makedirs(base_temp_dir, exist_ok=True)
        temp_dir_path = tempfile.mkdtemp(
            prefix=f"eoh_eval_W{worker_id}_", dir=base_temp_dir)
        temp_project_base = os.path.join(temp_dir_path, os.path.basename(
            self.project_dir))  # e.g., /tmp/.../LOS

        # --- Copy LOS project to temp directory ---
        # Use the stored absolute project_dir path
        los_source_path = self.project_dir
        try:
            # Ensure source exists before copying
            if not os.path.isdir(los_source_path):
                raise FileNotFoundError(
                    f"LOS source directory not found at: {los_source_path}")
            # Copy contents into the temp_project_base directory
            shutil.copytree(los_source_path, temp_project_base, symlinks=False,
                            ignore_dangling_symlinks=True, dirs_exist_ok=True)
            # print(f"[Worker {worker_id}] Copied LOS project from '{los_source_path}' to '{temp_project_base}'.")
        except Exception as e:
            print(
                f"[Worker {worker_id}] FATAL ERROR copying LOS project to temp dir '{temp_dir_path}': {e}")
            traceback.print_exc()
            # Cleanup potentially partially created directory
            if os.path.exists(temp_dir_path):
                shutil.rmtree(temp_dir_path)
            raise  # Re-raise the exception to stop the evaluation

        # --- Prepare Evaluation Arguments ---
        # Construct a descriptive base name for the output log directory within self.eval_log_dir
        evolved_modules_in_dict = [k for k in [
            "hica", "ghop", "eoss"] if k in individual_dict]
        # Or indicate which modules were fixed
        evolved_str = "_".join(k.upper(
        ) for k in evolved_modules_in_dict) if evolved_modules_in_dict else "BASELINE"

        # Extract epochs
        epochs_val = 'N/A'
        if '--finetune_epoch' in self.evaluation_args:
            try:
                epochs_idx = self.evaluation_args.index('--finetune_epoch')
                if epochs_idx + 1 < len(self.evaluation_args):
                    epochs_val = self.evaluation_args[epochs_idx + 1]
            except ValueError:  # Should ideally not happen if '--finetune_epoch' is checked
                pass  # Keep epochs_val as 'N/A'
        epochs_str = f"Epochs{epochs_val}"

        # Extract task string (short version, e.g., Taskscebslos)
        actual_tasks = []
        if '--tasks' in self.evaluation_args:
            try:
                tasks_start_index = self.evaluation_args.index('--tasks') + 1
                for i in range(tasks_start_index, len(self.evaluation_args)):
                    arg = self.evaluation_args[i]
                    # Stop if another command-line argument is encountered
                    if arg.startswith('--'):
                        break
                    actual_tasks.append(arg)
            except ValueError:  # Should ideally not happen if '--tasks' is checked
                pass  # Keep actual_tasks empty

        tasks_str_short = "Tasks" + \
            "".join(actual_tasks) if actual_tasks else "TasksN/A"

        # ---- NEW DIRECTORY NAMING LOGIC ----
        # 1. Extract tasks string
        actual_tasks_for_dirname = []
        if '--tasks' in self.evaluation_args:
            try:
                tasks_start_index = self.evaluation_args.index('--tasks') + 1
                for i in range(tasks_start_index, len(self.evaluation_args)):
                    arg = self.evaluation_args[i]
                    if arg.startswith('--'):
                        break
                    actual_tasks_for_dirname.append(arg)
            except ValueError:
                pass  # --tasks not found or malformed
        tasks_str_for_dirname = "+".join(
            actual_tasks_for_dirname) if actual_tasks_for_dirname else "NoTasks"

        # 2. Construct modules string from self.modules_to_evolve (passed in __init__)
        modules_str_for_dirname = "+".join([m.lower() for m in self.modules_to_evolve]
                                           ) if self.modules_to_evolve else "NoModulesEvolved"

        # 3. Construct population and total generations strings
        pop_str_for_dirname = f"pop{self.population_size}"
        totalgen_str_for_dirname = f"totalgen{self.total_generations}"

        # 4. The unique suffix (timestamp, worker_id, current_generation_id)
        #    This unique_id is already defined above: f"{timestamp}_W{worker_id}_G{individual_dict.get('generation', 'N/A')}"
        #    Let's rename it to unique_suffix for clarity in this context.
        unique_suffix = unique_id

        # 5. Combine to form the new output_log_subdir_name
        output_log_subdir_name = f"{tasks_str_for_dirname}_{pop_str_for_dirname}_{totalgen_str_for_dirname}_{modules_str_for_dirname}_{unique_suffix}"
        # ---- END NEW DIRECTORY NAMING LOGIC ----

        # Ensure log dir exists (should have been created in __init__)
        if self.eval_log_dir:
            full_output_log_dir = os.path.join(
                self.eval_log_dir, output_log_subdir_name)
        else:
            # Fallback if eval_log_dir failed to create
            print(
                f"[Worker {worker_id}] Warning: eval_log_dir not set, logging to temp dir: {temp_dir_path}")
            full_output_log_dir = os.path.join(temp_dir_path, "eval_output")
            os.makedirs(full_output_log_dir, exist_ok=True)

        # Assign GPU ID cyclically based on worker ID
        assigned_gpu_id = self.target_gpu_ids[int(
            worker_id) % self.num_gpus_to_use]
        print(f"[Worker {worker_id}] Assigned GPU: {assigned_gpu_id}")

        # Prepare final eval args, adding CUDA_VISIBLE_DEVICES and dynamic output path
        final_eval_args = list(self.evaluation_args)  # Copy base args
        # Add dynamic output path
        final_eval_args.extend(["--out", full_output_log_dir])
        # Explicitly set GPU for the script
        final_eval_args.extend(["--gpu", str(assigned_gpu_id)])

        # Define the main log file path within the output directory
        # Change log file name to avoid collision with logger file handler
        subprocess_log_file_path = os.path.join(
            full_output_log_dir, "subprocess.log")

        # --- Create Metadata File --- #
        metadata = {
            "directory_name": output_log_subdir_name,
            "timestamp_ms": unique_id,
            "evolved_modules_calculated": evolved_str,
            "evaluation_script": self.evaluation_script,
            # Log base args before dynamic additions
            "evaluation_args": self.evaluation_args,
            "assigned_gpu_id": assigned_gpu_id,
            "worker_id": worker_id,
            "generation": individual_dict.get('generation', 'N/A')
            # Add fitness/result later after evaluation
        }

        # --- Add Evolved Code to Metadata ---
        evolved_code_to_store = {}
        possible_evolved_modules = ["hica", "ghop", "eoss"]
        for module_name in possible_evolved_modules:
            # Check if the module was intended for evolution (not fixed)
            if module_name not in self.prompts.fixed_modules_code:
                # Check if the code for this evolved module exists in the individual_dict
                if module_name in individual_dict and individual_dict[module_name]:
                    evolved_code_to_store[f"evolved_{module_name}_code"] = individual_dict[module_name]

        if evolved_code_to_store:
            metadata["evolved_llm_code"] = evolved_code_to_store
        # --- End Add Evolved Code ---

        metadata_file_path = os.path.join(full_output_log_dir, "metadata.json")
        try:
            # Ensure dir exists before writing
            os.makedirs(full_output_log_dir, exist_ok=True)
            with open(metadata_file_path, 'w') as mf:
                json.dump(metadata, mf, indent=4)
        except Exception as meta_e:
            print(
                f"[Worker {worker_id}] Warning: Could not write metadata file {metadata_file_path}: {meta_e}")
        # --- End Metadata --- #

        return temp_dir_path, temp_project_base, subprocess_log_file_path, final_eval_args

    def _extract_code_block(self, code_string, start_marker, end_marker):
        """Extracts a code block between specified markers."""
        try:
            start_index = code_string.index(start_marker) + len(start_marker)
            end_index = code_string.index(end_marker)
            block = code_string[start_index:end_index].strip()
            # Remove the potential header comment line like '# --- Code for ... ---'
            lines = block.splitlines()
            if lines and lines[0].strip().startswith("# ---"):
                block = "\n".join(lines[1:]).strip()
            return block
        except ValueError:
            print(
                f"Error: Could not find markers '{start_marker}' or '{end_marker}'")
            return None

    def _inject_hica(self, hica_code):
        """Injects the Pla class code into MBACK.py."""
        target_file = self.hica_target_file
        try:
            with open(target_file, "r", encoding="utf-8") as f:
                original_content = f.readlines()

            start_line, end_line = -1, -1
            class_def_pattern = re.compile(r"^\s*class\s+Pla\b")

            for i, line in enumerate(original_content):
                if start_line == -1 and class_def_pattern.match(line):
                    start_line = i
                    indent_level = len(line) - len(line.lstrip(" "))
                elif start_line != -1:
                    current_indent = len(line) - len(line.lstrip(" "))
                    is_new_block_start = line.strip() and (
                        line.lstrip().startswith("class ")
                        or line.lstrip().startswith("def ")
                    )
                    if line.strip() and current_indent < indent_level:
                        end_line = i
                        break
                    elif (
                        line.strip()
                        and current_indent == indent_level
                        and is_new_block_start
                    ):
                        end_line = i
                        break

            if start_line != -1 and end_line == -1:
                end_line = len(original_content)

            if start_line == -1 or end_line == -1:
                print(
                    f"Error: Could not find definition of 'class Pla' in {target_file}"
                )
                return False

            # Basic indentation (assuming LLM provides full class with imports)
            # We replace the whole block, so less sensitive to internal LLM indent
            injected_lines = [line + "\n" for line in hica_code.splitlines()]
            if not injected_lines[-1].endswith("\n"):
                injected_lines[-1] += "\n"

            new_content = (
                original_content[:start_line]
                + injected_lines
                + original_content[end_line:]
            )

            print(
                f"DEBUG: Injecting HICA code into {os.path.basename(target_file)} between lines {start_line+1} and {end_line+1}"
            )
            with open(target_file, "w", encoding="utf-8") as f:
                f.writelines(new_content)
            return True
        except Exception as e:
            print(f"Error injecting HICA code into {target_file}: {e}")
            traceback.print_exc()
            return False

    def _inject_method(self, target_file, class_name, method_name, method_code):
        """Injects a method body into a class. (Used for GHOP)."""  # Updated docstring
        # is_ghop = (method_name == "_project_conflicting") # No longer needed for prints
        # is_eoss = (method_name == "cat_targets") # No longer needed for prints

        try:
            print(
                f"DEBUG: Reading {os.path.basename(target_file)} for injecting method '{method_name}'..."
            )
            with open(target_file, "r", encoding="utf-8") as f:
                original_content = f.readlines()

            # --- Find method start and end (Keep existing logic) ---
            start_line, end_line = -1, -1
            method_def_pattern = re.compile(rf"^(\s*)def\s+{method_name}\b\(")
            class_def_pattern = re.compile(rf"^\s*class\s+{class_name}\b")
            base_indent_level, method_indent_level = -1, -1
            in_class = False

            for i, line in enumerate(original_content):
                if not in_class and class_def_pattern.match(line):
                    in_class = True
                    base_indent_level = len(line) - len(line.lstrip(" "))
                    continue
                if not in_class:
                    continue
                match = method_def_pattern.match(line)
                if start_line == -1 and match:
                    start_line = i
                    method_indent_str = match.group(1)
                    method_indent_level = len(method_indent_str)
                    continue
                if start_line != -1:
                    # Ensure this block has consistent indentation (Level 2 within the outer if)
                    current_indent = len(line) - len(line.lstrip(" "))
                    is_not_empty = line.strip() != ""
                    is_less_indented = current_indent < method_indent_level
                    is_new_def_or_class = line.lstrip().startswith(("def ", "class "))
                    if (
                        is_not_empty
                        and (current_indent <= method_indent_level)
                        and (is_less_indented or is_new_def_or_class)
                    ):
                        end_line = i
                        break
            if start_line != -1 and end_line == -1:
                end_line = len(original_content)
            if start_line == -1 or end_line == -1:
                print(
                    f"Error: Could not find method '{method_name}' in class '{class_name}' within {target_file}"
                )
                return False
            # --- End Find method ---

            # --- Pre-process and Indent (Keep existing logic) ---
            code_string_lines = method_code.splitlines(keepends=True)
            first_non_blank = -1
            last_non_blank = -1
            for idx, line in enumerate(code_string_lines):
                if line.strip():
                    if first_non_blank == -1:
                        first_non_blank = idx
                    last_non_blank = idx
            if first_non_blank != -1:
                code_string_lines = code_string_lines[
                    first_non_blank: last_non_blank + 1
                ]
            else:
                code_string_lines = []

            if not code_string_lines:
                print(
                    f"Error: Method code for {method_name} cannot be empty.")
                return False
            code_string_lines = [
                line
                for line in code_string_lines
                if not line.strip().startswith("import ")
            ]
            if not code_string_lines:
                print(
                    f"Error: Method code for {method_name} became empty after removing imports."
                )
                return False

            provided_lines = code_string_lines
            match = re.match(r"^(\s*)", provided_lines[0])
            first_line_indent_str = match.group(1) if match else ""

            injected_lines = []
            body_indent_level = method_indent_level + 4
            body_indent_str = " " * body_indent_level

            # --- REMOVED EOSS ENTRY DEBUG PRINT ---

            for line_idx, line in enumerate(provided_lines):
                line_ending = ""
                if line.endswith("\r\n"):
                    line_ending = "\r\n"
                elif line.endswith("\n"):
                    line_ending = "\n"
                stripped_content = line.strip()
                if not stripped_content:
                    injected_lines.append(line_ending)
                    continue
                stripped_line_l = line.lstrip()
                original_line_indent = len(line) - len(stripped_line_l)
                relative_indent_level = original_line_indent - len(
                    first_line_indent_str
                )
                new_indent_level = body_indent_level + relative_indent_level
                new_indent_str = " " * max(0, new_indent_level)
                injected_lines.append(
                    new_indent_str + stripped_content + line_ending)

            if (
                injected_lines
                and injected_lines[-1].strip()
                and not injected_lines[-1].endswith(os.linesep)
            ):
                preferred_ending = "\n" if os.linesep != "\r\n" else os.linesep
                injected_lines[-1] += preferred_ending
            # --- End Pre-process and Indent ---

            # --- REMOVED DEBUG PRINT BEFORE RETURN LOGIC ---

            # Construct new content
            prefix_lines = original_content[: start_line + 1]
            if prefix_lines and not prefix_lines[-1].endswith(os.linesep):
                prefix_lines[-1] += os.linesep

            new_content = prefix_lines + injected_lines + \
                original_content[end_line:]

            print(
                f"DEBUG: Injecting method '{method_name}' into {os.path.basename(target_file)} between lines {start_line+2} and {end_line+1}"
            )
            with open(target_file, "w", encoding="utf-8") as f:
                f.writelines(new_content)
            return True
        except Exception as e:
            print(
                f"Error injecting method '{method_name}' into {target_file}: {e}")
            traceback.print_exc()
            return False

    def _inject_code(self, individual_dict: dict, temp_project_base: str) -> bool:
        """Inject the evolved HICA, GHOP and EOSS code into the copied project.

        A module listed in ``self.prompts.fixed_modules_code`` is left untouched
        and counts as a success. Every other module must come with a non-empty
        code string in ``individual_dict``; missing or empty code is an
        injection failure for that module, and therefore for the whole call.

        Injection targets:
            * "hica": replaces class ``Pla`` in ``MultiBackward/MBACK.py``.
            * "ghop": replaces ``PCGrad._project_conflicting`` in the file at
              ``self.ghop_target_relpath``.
            * "eoss": replaces ``Weight_acc.cat_targets`` in the file at
              ``self.eoss_target_relpath`` (code is stripped first).

        Args:
            individual_dict: Maps "hica", "ghop" and "eoss" to code strings. The
                optional "worker_id" entry is only used to tag log lines.
            temp_project_base: Root directory of the temporary project copy.

        Returns:
            True when every non-fixed module was injected successfully (or when
            all modules are fixed), False otherwise.
        """
        fixed_modules = self.prompts.fixed_modules_code
        worker_tag = f"[Worker {individual_dict.get('worker_id', 'N/A')}]"

        # Fixed modules are never touched, so "no error" is the default.
        injection_status = {
            "hica": True,
            "ghop": True,
            "eoss": True,
        }

        # --- HICA: replace the whole `Pla` class --- #
        if "hica" not in fixed_modules:
            hica_code = individual_dict.get("hica")
            if hica_code:
                print(
                    f"{worker_tag} DEBUG: Attempting to inject EVOLVED HICA code...")
                injection_status["hica"] = bool(inject_class_definition(
                    os.path.join(temp_project_base, "MultiBackward",
                                 "MBACK.py"),
                    "Pla", hica_code
                ))
                if injection_status["hica"]:
                    print(
                        f"{worker_tag} DEBUG: inject_class_definition reported SUCCESS for HICA.")
                else:
                    print(
                        f"{worker_tag} Error: inject_class_definition reported FAILURE for HICA.")
            else:
                print(
                    f"{worker_tag} Warning: HICA set to evolve, but 'hica' code was missing or empty in individual_dict.")
                injection_status["hica"] = False
        else:
            print(
                f"{worker_tag} Info: Skipping HICA injection as it's fixed.")

        # --- GHOP: replace PCGrad._project_conflicting --- #
        if "ghop" not in fixed_modules:
            ghop_code = individual_dict.get("ghop")
            if ghop_code:
                print(
                    f"{worker_tag} DEBUG: Attempting to inject EVOLVED GHOP code (obtained key: ghop)."
                )
                target_file_path_ghop = os.path.join(
                    temp_project_base, self.ghop_target_relpath)
                injection_status["ghop"] = self._inject_method(
                    target_file_path_ghop, "PCGrad", "_project_conflicting", ghop_code
                )
                if not injection_status["ghop"]:
                    print(
                        f"Error injecting EVOLVED GHOP code into {target_file_path_ghop}")
            else:
                print(
                    "Warning: GHOP set to evolve, but no 'ghop' code found in individual_dict."
                )
                injection_status["ghop"] = False
        else:
            print("Info: Skipping GHOP injection as it's fixed.")

        # --- EOSS: replace Weight_acc.cat_targets --- #
        if "eoss" not in fixed_modules:
            eoss_code = individual_dict.get("eoss")
            if eoss_code:
                cat_targets_body = eoss_code.strip()
                if cat_targets_body:
                    print(
                        f"{worker_tag} DEBUG: Attempting to inject EVOLVED EOSS cat_targets body (obtained key: eoss, stripped len={len(cat_targets_body)})."
                    )
                    target_file_path_eoss = os.path.join(
                        temp_project_base, self.eoss_target_relpath)
                    injection_status["eoss"] = self._inject_method(
                        target_file_path_eoss, "Weight_acc", "cat_targets", cat_targets_body
                    )
                    if not injection_status["eoss"]:
                        print(
                            f"Error injecting EVOLVED EOSS cat_targets body into {target_file_path_eoss}"
                        )
                else:
                    # Whitespace-only code is as unusable as no code at all.
                    print(
                        "Error: EOSS set to evolve, but received empty code after stripping."
                    )
                    injection_status["eoss"] = False
            else:
                print(
                    "Warning: EOSS set to evolve, but no 'eoss' code found in individual_dict."
                )
                injection_status["eoss"] = False
        else:
            print("Info: Skipping EOSS injection as it's fixed.")

        # Judge success over every module that was *supposed* to evolve, not only
        # those that happened to carry code: a module with missing code must fail
        # the injection instead of silently evaluating the unmodified project.
        # With no evolving module at all, all() over nothing is True.
        final_success = all(
            injection_status[module]
            for module in injection_status
            if module not in fixed_modules
        )
        if not final_success:
            print(
                "Critical Error: At least one EVOLVED module injection failed in temp dir."
            )
            return False

        print(
            f"{worker_tag} Temp LOS Injection summary: {injection_status} -> Overall Success: {final_success}"
        )
        return final_success

    def evaluate(self, individual_dict: dict, **kwargs):
        """Evaluate one individual by running the LOS script on its injected code.

        Stages:
            1. Build a temporary project copy and inject the evolved code.
            2. Run ``self.evaluation_script`` from that copy as a subprocess,
               pinned to one GPU via ``CUDA_VISIBLE_DEVICES``, with stdout and
               stderr captured in the log file; the run is aborted after
               ``EVALUATION_TIMEOUT`` seconds.
            3. Parse "Best Acc" from the log and turn it into the fitness.
            4. Always remove the temporary directory.

        Args:
            individual_dict: The individual to evaluate. It is modified in
                place: "worker_id" and "generation" entries are added.
            **kwargs: Optional ``worker_id`` (also used to pick the GPU) and
                ``generation``; both default to "N/A".

        Returns:
            A one-element list. On success it holds the squared best accuracy
            (or the accuracy itself when it is 0); on any failure (injection
            error, non-zero exit code, timeout, unparsable log, unexpected
            exception) it holds the penalty value ``-1e9``.
        """
        worker_id = kwargs.get("worker_id", "N/A")  # Useful for logging
        # Downstream helpers read these two entries from the individual itself.
        individual_dict["worker_id"] = worker_id
        individual_dict["generation"] = kwargs.get("generation", "N/A")

        N_OBJECTIVES = 1  # Single objective: the (scaled) best accuracy

        temp_dir = None  # Stays None if environment preparation itself fails
        eval_success = False
        final_fitness_score = [-1e9] * N_OBJECTIVES  # Penalty until proven otherwise

        try:
            # --- Stage 1: Prepare Environment & Inject Code ---
            temp_dir, temp_project_base, log_file_path, eval_args = self._prepare_evaluation_environment(
                individual_dict)

            if not self._inject_code(individual_dict, temp_project_base):
                print(
                    f"[Worker {worker_id}] ERROR: Code injection failed. Evaluation cannot proceed.")
                return final_fitness_score  # Penalty; `finally` still cleans up

            # --- Stage 2: Run LOS Evaluation Script as Subprocess ---
            eval_script_path = os.path.join(
                temp_project_base, self.evaluation_script)
            eval_command = [sys.executable, eval_script_path] + eval_args

            try:
                assigned_gpu_id = self.target_gpu_ids[int(
                    worker_id) % self.num_gpus_to_use]
            except (ValueError, IndexError, TypeError) as e:
                print(
                    f"[Worker {worker_id}] ERROR: Could not determine assigned GPU ID: {e}. Defaulting to 0.")
                assigned_gpu_id = 0

            # A bare file name has no directory part, and os.makedirs("") raises.
            log_dir = os.path.dirname(log_file_path)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)

            sub_env = os.environ.copy()
            sub_env["CUDA_VISIBLE_DEVICES"] = str(assigned_gpu_id)

            with open(log_file_path, "w") as log_f:
                process = subprocess.Popen(
                    eval_command,
                    cwd=temp_project_base,
                    stdout=log_f,
                    stderr=subprocess.STDOUT,
                    text=True,
                    env=sub_env
                )
                try:
                    process_return_code = process.wait(
                        timeout=EVALUATION_TIMEOUT)
                except subprocess.TimeoutExpired:
                    process_return_code = None  # Marks the timeout case below
                    print(
                        f"[Worker {worker_id}] ERROR: Evaluation subprocess timed out after {EVALUATION_TIMEOUT} seconds. Terminating.")
                    try:
                        process.terminate()
                        process.wait(timeout=5)
                    except subprocess.TimeoutExpired:
                        # Graceful shutdown ignored: force-kill and reap the child.
                        process.kill()
                        process.wait()
                    except Exception as term_err:
                        print(
                            f"[Worker {worker_id}] Warning: Error during process termination: {term_err}")

            if process_return_code == 0:
                eval_success = True
            elif process_return_code is not None:
                print(
                    f"[Worker {worker_id}] ERROR: Evaluation subprocess failed (Return Code: {process_return_code}). Check log: {log_file_path}")
            # On timeout or failure the fitness stays at the penalty.

            # --- Stage 3: Parse Results from Log File ---
            if eval_success:
                parsed_metrics = self._parse_los_log(log_file_path)
                if parsed_metrics:
                    best_acc = parsed_metrics.get("Best Acc", -1.0)
                    if best_acc >= 0:
                        current_fitness = best_acc
                        if current_fitness > 0:
                            # Squaring the accuracy widens the gap between
                            # good and mediocre individuals.
                            scaling_exponent = 2.0
                            scaled_fitness = current_fitness ** scaling_exponent
                            print(
                                f"[Worker {worker_id}] Original Best Acc: {current_fitness:.4f}, Scaled Fitness: {scaled_fitness:.4f}")
                            final_fitness_score = [scaled_fitness]
                        else:
                            final_fitness_score = [current_fitness]
                    else:
                        print(
                            f"[Worker {worker_id}] Warning: 'Best Acc' not found or invalid in log. Using penalty.")
                else:
                    print(
                        f"[Worker {worker_id}] Warning: Parsing log file failed. Using penalty.")

        except Exception as prep_err:  # Boundary: any stage failure becomes the penalty
            print(
                f"[Worker {worker_id}] FATAL ERROR during evaluation stages: {prep_err}")
            traceback.print_exc()
        finally:
            # --- Stage 4: Cleanup ---
            self._cleanup_temp_dir(temp_dir)

        return final_fitness_score

    def _cleanup_temp_dir(self, temp_dir_path: str):
        """Remove the temporary evaluation directory tree, if there is one.

        Nothing happens for an empty/None path or a path that does not exist.
        A failure to delete is reported as a warning and never raised, so that
        cleanup cannot mask the result of the evaluation itself.
        """
        nothing_to_remove = not temp_dir_path or not os.path.exists(
            temp_dir_path)
        if nothing_to_remove:
            return

        try:
            shutil.rmtree(temp_dir_path)
        except Exception as cleanup_err:
            print(
                f"[Worker ???] Warning: Failed to remove temp directory '{temp_dir_path}': {cleanup_err}")

    def _parse_los_log(self, log_file_path: str) -> dict | None:
        """
        Parses the LOS evaluation stdout log file to extract key metrics.
        Looks for the 'Best' accuracy line printed during validation epochs.
        Example line: '> [Best ]     Acc:    55.6400 Many:   78.66   Medium: 51.49   Few:    33.63'
        """
        best_acc = -1.0
        best_many = -1.0
        best_medium = -1.0
        best_few = -1.0
        # Accuracy from the very last epoch (if different from best)
        final_acc = -1.0

        # --- MODIFIED: New regex to find the FINAL summary lines --- #
        # Looks for lines like '     best bAcc (test):    55.8700' (handles variable leading whitespace)
        final_bacc_pattern = re.compile(
            r"^\s*(?:└>\s*)?best bAcc \(test\):\s*([\d\.]+)"
        )
        # Optional: Regex to capture final class-wise stats
        final_stats_pattern = re.compile(
            r"^\s*(?:└>\s*)?best statistics:\s*Many:\s*([\d\.]+)\s*Med:\s*([\d\.]+)\s*Few:\s*([\d\.]+)"
        )

        # --- REMOVED Old regex patterns for [Best ] and [Test ] lines --- #

        try:
            with open(log_file_path, "r") as f:
                lines = f.readlines()

            # --- MODIFIED: Search for the final summary lines --- #
            # Search the *entire* log file for the specific final summary lines
            found_bacc = False
            # Search from the end is usually faster for summary lines
            for line in reversed(lines):
                bacc_match = final_bacc_pattern.search(line)
                if bacc_match:
                    best_acc = float(bacc_match.group(1))
                    found_bacc = True
                    # We can optionally search for the stats line here too if needed
                    # For now, just break once bAcc is found
                    break

            # Optional: Extract detailed stats if needed (can search again or within the same loop)
            # for line in reversed(lines):
            #     stats_match = final_stats_pattern.search(line)
            #     if stats_match:
            #         best_many = float(stats_match.group(1))
            #         best_medium = float(stats_match.group(2))
            #         best_few = float(stats_match.group(3))
            #         break
            # --- End Modified Search --- #

            # --- MODIFIED: Return logic based on finding final_bacc --- #
            if found_bacc:
                # print(f"Successfully parsed final bAcc: {best_acc}%")
                # Return the primary metric. Add others if needed later.
                return {
                    "Best Acc": best_acc,
                    # Include these if extracted and needed by eoh/analysis
                    # "Best Many": best_many if best_many >= 0 else None,
                    # "Best Medium": best_medium if best_medium >= 0 else None,
                    # "Best Few": best_few if best_few >= 0 else None,
                }
            else:
                print(
                    f"Warning: Could not find the final '> best bAcc (test):' line in log: {log_file_path}")
                return None  # Indicate failure to find the required metric

        except FileNotFoundError:
            print(f"Error: Log file not found for parsing: {log_file_path}")
            return None
        except Exception as e:
            print(f"Error parsing log file '{log_file_path}': {e}")
            traceback.print_exc()
            return None

    def get_nicer_objective_values(self, result: dict) -> dict:
        """Return a human-readable summary of an evaluation result.

        Args:
            result: Result mapping; its "objective_values" entry is expected to
                be the list produced by ``evaluate``.

        Returns:
            ``{"accuracy": "<value>%"}`` for a successful evaluation, or
            ``{"status": "Evaluation Failed or Timed Out"}`` when the objective
            is missing, empty, or still at the ``-1e9`` penalty.
        """
        obj_values = result.get("objective_values", [])
        # Anything at (or below) the penalty means the evaluation did not succeed.
        if obj_values and obj_values[0] > -1e9:
            fitness = obj_values[0]
            # `evaluate` stores Best Acc ** 2 for positive accuracies, so undo
            # the squaring to report the accuracy itself rather than the fitness.
            accuracy = fitness ** 0.5 if fitness > 0 else fitness
            return {"accuracy": f"{accuracy:.2f}%"}
        return {"status": "Evaluation Failed or Timed Out"}
