"""Controller for prompt-only multitask OpenEvolve runs."""

from __future__ import annotations

import asyncio
import base64
import hashlib
import json
import logging
import os
import pickle
import random
import shutil
import time
import uuid
from concurrent.futures.process import BrokenProcessPool
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import asdict, dataclass, field, replace
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from openevolve.config import Config, PromptConfig
from openevolve.database import Program, ProgramDatabase
from openevolve.evaluator import Evaluator
from openevolve.evolution_trace import EvolutionTracer
from openevolve.llm.ensemble import LLMEnsemble
from openevolve.multitask.config import (
    MultitaskConfig,
    RelatedTaskConfig,
    TaskConfig,
    derive_task_config,
    load_base_task_config,
)
from openevolve.multitask.parallel_worker import (
    DedicatedTaskWorker,
    InitialProgramEvaluationRequest,
    TaskIterationRequest,
    TaskIterationWorkerResult,
    WorkerRngState,
)
from openevolve.prompt.sampler import PromptSampler
from openevolve.utils.code_utils import (
    apply_diff,
    apply_diff_blocks,
    extract_code_language,
    extract_diffs,
    format_diff_summary,
    parse_full_rewrite,
    split_diffs_by_target,
)
from openevolve.utils.format_utils import format_metrics_safe
from openevolve.utils.metrics_utils import get_fitness_score, safe_numeric_average
from openevolve.utils.wandb_logger import create_wandb_logger, flatten_scalars

logger = logging.getLogger(__name__)
_CURRENT_TASK_LOG_CONTEXT: ContextVar[str] = ContextVar(
    "openevolve_multitask_task_log_context", default="multitask"
)


class _TaskContextFilter(logging.Filter):
    """Inject the current multitask task label into log records."""

    def filter(self, record: logging.LogRecord) -> bool:
        if not getattr(record, "task_name", None):
            record.task_name = _CURRENT_TASK_LOG_CONTEXT.get()
        return True


class _SpecificTaskFilter(_TaskContextFilter):
    """Allow only records for a specific multitask task label."""

    def __init__(self, task_name: str):
        super().__init__()
        self._task_name = task_name

    def filter(self, record: logging.LogRecord) -> bool:
        super().filter(record)
        return record.task_name == self._task_name


def _serialize_state(value: Any) -> str:
    return base64.b64encode(pickle.dumps(value)).decode("utf-8")


def _deserialize_state(value: str) -> Any:
    return pickle.loads(base64.b64decode(value.encode("utf-8")))


def _normalize_json_compatible(value: Any) -> Any:
    """Recursively normalize values for config snapshots and resume validation."""
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, dict):
        return {str(key): _normalize_json_compatible(item) for key, item in value.items()}
    if isinstance(value, (list, tuple, set)):
        return [_normalize_json_compatible(item) for item in value]
    if callable(value):
        module = getattr(value, "__module__", None)
        qualname = getattr(value, "__qualname__", getattr(value, "__name__", repr(value)))
        return {"__callable__": f"{module}.{qualname}" if module else qualname}
    if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")):
        return _normalize_json_compatible(value.to_dict())
    return repr(value)


def _normalize_recent_child_fitness_history(values: Any) -> List[float]:
    """Best-effort normalization for checkpointed reward-history state."""
    history: List[float] = []
    for value in values or []:
        try:
            fitness = float(value)
        except (TypeError, ValueError):
            continue
        if np.isfinite(fitness):
            history.append(fitness)
    return history


@contextmanager
def temporary_env(env: Dict[str, str]):
    """Temporarily apply environment variables for a task."""
    previous_values: Dict[str, Optional[str]] = {}
    try:
        for key, value in env.items():
            previous_values[key] = os.environ.get(key)
            os.environ[key] = value
        yield
    finally:
        for key, previous_value in previous_values.items():
            if previous_value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous_value


@dataclass
class TaskState:
    """Runtime state for a single task in a multitask run."""

    task_name: str
    initial_program_path: str
    initial_program_code: str
    evaluation_file: str
    config: Config
    database: ProgramDatabase
    evaluator: Evaluator
    prompt_sampler: PromptSampler
    evaluator_prompt_sampler: PromptSampler
    llm_ensemble: LLMEnsemble
    llm_evaluator_ensemble: LLMEnsemble
    output_dir: str
    file_extension: str
    env: Dict[str, str]
    related_tasks: List[RelatedTaskConfig]
    local_iteration: int = 0
    no_improve_steps: int = 0
    last_improvement_iteration: Optional[int] = None
    last_transfer_iteration: Optional[int] = None
    transfer_bandit_alpha: Dict[str, float] = field(default_factory=dict)
    transfer_bandit_beta: Dict[str, float] = field(default_factory=dict)
    transfer_bandit_pulls: Dict[str, int] = field(default_factory=dict)
    recent_child_fitness_history: List[float] = field(default_factory=list)
    checkpoint_metadata: Dict[str, Any] = field(default_factory=dict)
    random_state: Any = None
    numpy_random_state: Any = None
    evolution_tracer: Optional[EvolutionTracer] = None


@dataclass
class TaskIterationResult:
    """Outcome of a single multitask step."""

    task_name: str
    local_iteration: int
    success: bool
    child_program: Optional[Program] = None
    failure_reason: Optional[str] = None
    generation_time_sec: Optional[float] = None
    evaluation_time_sec: Optional[float] = None
    iteration_time_sec: Optional[float] = None
    foreign_inspiration_sources: List[str] = field(default_factory=list)
    foreign_transfer_trigger_reason: Optional[str] = None
    chosen_transfer_arm: Optional[str] = None


@dataclass(frozen=True)
class FrozenTaskTransferState:
    """Immutable view of transfer-related task state for one prepared wave."""

    local_iteration: int
    no_improve_steps: int
    last_improvement_iteration: Optional[int]
    last_transfer_iteration: Optional[int]
    transfer_bandit_alpha: Dict[str, float]
    transfer_bandit_beta: Dict[str, float]
    transfer_bandit_pulls: Dict[str, int]


@dataclass
class ForeignTransferDecision:
    """Decision for whether and how to include foreign inspirations this step."""

    trigger_mode: str
    trigger_reason: Optional[str] = None
    chosen_transfer_arm: Optional[str] = None
    foreign_inspirations: List[Dict[str, Any]] = field(default_factory=list)


@dataclass(frozen=True)
class TaskProgressUpdate:
    """Committed-step progress summary used for logging and reward updates."""

    delta_best: float
    improvement_detected: bool
    reward_for_chosen_arm: Optional[int]
    reward_mode: Optional[str] = None
    child_fitness_for_reward: Optional[float] = None
    reward_baseline_fitness: Optional[float] = None


class MultiTaskOpenEvolve:
    """Run multiple related OpenEvolve tasks with prompt-only cross-task transfer."""

    def __init__(self, multitask_config: MultitaskConfig):
        self.multitask_config = multitask_config
        self.output_dir = multitask_config.output_dir
        os.makedirs(self.output_dir, exist_ok=True)

        self.base_config = load_base_task_config(multitask_config)
        self._resume_validation_payload = self._build_resume_validation_payload()
        self._resume_validation_hash = self._hash_resume_validation_payload(
            self._resume_validation_payload
        )
        self._log_timestamp = time.strftime("%Y%m%d_%H%M%S")
        self.root_log_dir: Optional[str] = None
        self._task_log_paths: Dict[str, str] = {}
        self._setup_logging()

        self.tasks: List[TaskState] = []
        self.task_by_name: Dict[str, TaskState] = {}
        self._claimed_task_output_paths: Dict[str, Tuple[str, str]] = {}
        self.completed_global_iterations = 0
        self.next_task_index = 0
        self._last_checkpoint_iteration = 0
        self._scheduler_counts: Dict[str, int] = {}
        self._task_best_fitness: Dict[str, Optional[float]] = {}
        self._run_started_at: Optional[float] = None
        self.wandb_logger = create_wandb_logger(
            self.multitask_config.wandb or self.base_config.wandb,
            self.output_dir,
        )

        self._initialize_tasks()

    def _setup_logging(self) -> None:
        log_dir = self._resolve_root_output_path(
            configured_path=self.base_config.log_dir,
            default_path="logs",
        )
        self.root_log_dir = log_dir
        os.makedirs(log_dir, exist_ok=True)

        root_logger = logging.getLogger()
        root_logger.setLevel(getattr(logging, self.base_config.log_level))
        context_filter = _TaskContextFilter()

        log_file = os.path.join(log_dir, f"openevolve_multitask_{self._log_timestamp}.log")
        file_handler = logging.FileHandler(log_file)
        file_handler.addFilter(context_filter)
        file_handler.setFormatter(
            logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - [%(task_name)s] %(message)s"
            )
        )
        root_logger.addHandler(file_handler)

        console_handler = logging.StreamHandler()
        console_handler.addFilter(context_filter)
        console_handler.setFormatter(
            logging.Formatter("%(asctime)s - %(levelname)s - [%(task_name)s] %(message)s")
        )
        root_logger.addHandler(console_handler)

        logger.info("Logging multitask run to %s", log_file)

    def _ensure_task_log_handler(self, task_name: str, log_dir: str) -> str:
        """Attach a task-specific file handler for records from one task."""
        existing_path = self._task_log_paths.get(task_name)
        if existing_path:
            return existing_path

        os.makedirs(log_dir, exist_ok=True)
        task_log_path = os.path.join(
            log_dir, f"openevolve_task_{task_name}_{self._log_timestamp}.log"
        )
        handler = logging.FileHandler(task_log_path)
        handler.addFilter(_SpecificTaskFilter(task_name))
        handler.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - [%(task_name)s] %(message)s")
        )
        logging.getLogger().addHandler(handler)
        self._task_log_paths[task_name] = task_log_path
        return task_log_path

    def _parallel_worker_log_path(self, task_name: str, log_dir: str) -> str:
        """Return the file path for one task's dedicated parallel-worker log."""
        os.makedirs(log_dir, exist_ok=True)
        return os.path.join(
            log_dir, f"openevolve_parallel_worker_{task_name}_{self._log_timestamp}.log"
        )

    @contextmanager
    def _activate_task_logging(self, task_name: str):
        token = _CURRENT_TASK_LOG_CONTEXT.set(task_name)
        try:
            yield
        finally:
            _CURRENT_TASK_LOG_CONTEXT.reset(token)

    def _resolve_root_output_path(
        self, configured_path: Optional[str], default_path: str
    ) -> str:
        """Resolve a multitask-level runtime output path under the run output dir."""
        if configured_path:
            path = Path(configured_path).expanduser()
            if not path.is_absolute():
                path = Path(self.output_dir) / path
        else:
            path = Path(self.output_dir) / default_path

        return str(path.resolve())

    def _build_resume_validation_payload(self) -> Dict[str, Any]:
        """Capture the config structure that must remain stable across resume."""
        return {
            "base_config_path": self.multitask_config.base_config_path,
            "execution_mode": self.multitask_config.execution_mode,
            "scheduler": asdict(self.multitask_config.scheduler),
            "foreign_inspirations": self._normalize_foreign_inspirations_for_resume(
                asdict(self.multitask_config.foreign_inspirations)
            ),
            "base_config": self._normalize_base_config_for_resume(self.base_config.to_dict()),
            "tasks": [
                {
                    "name": task.name,
                    "initial_program": task.initial_program,
                    "evaluation_file": task.evaluation_file,
                    "env": dict(sorted(task.env.items())),
                    "config_overrides": _normalize_json_compatible(task.config_overrides),
                    "related_tasks": [
                        {
                            "source_task": related_task.source_task,
                            "enabled": related_task.enabled,
                            "prompt_context": related_task.prompt_context,
                        }
                        for related_task in task.related_tasks
                    ],
                }
                for task in self.multitask_config.tasks
            ],
        }

    def _normalize_base_config_for_resume(self, base_config: Dict[str, Any]) -> Dict[str, Any]:
        """Drop runtime-only fields so resume checks focus on behavioral config changes."""
        normalized = _normalize_json_compatible(base_config)
        normalized.pop("max_iterations", None)
        normalized.pop("checkpoint_interval", None)
        normalized.pop("log_dir", None)

        database_config = normalized.get("database")
        if isinstance(database_config, dict):
            database_config.pop("db_path", None)
            database_config.pop("artifacts_base_path", None)

        evolution_trace_config = normalized.get("evolution_trace")
        if isinstance(evolution_trace_config, dict):
            evolution_trace_config.pop("output_path", None)

        return normalized

    def _normalize_foreign_inspirations_for_resume(
        self, foreign_inspirations: Any
    ) -> Dict[str, Any]:
        """Canonicalize foreign-inspiration config snapshots across schema versions."""
        defaults = asdict(self.multitask_config.foreign_inspirations.__class__())
        normalized = dict(defaults)
        raw = _normalize_json_compatible(foreign_inspirations or {})
        if not isinstance(raw, dict):
            return normalized

        if "min_improvement" in raw and "min_best_fitness_improvement" not in raw:
            raw["min_best_fitness_improvement"] = raw.pop("min_improvement")
        else:
            raw.pop("min_improvement", None)

        normalized.update(raw)
        return normalized

    @staticmethod
    def _hash_resume_validation_payload(payload: Dict[str, Any]) -> str:
        serialized = json.dumps(payload, sort_keys=True, separators=(",", ":"))
        return hashlib.sha256(serialized.encode("utf-8")).hexdigest()

    def _build_resume_validation_payload_from_snapshot(
        self, snapshot: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        validation = snapshot.get("resume_validation")
        if isinstance(validation, dict):
            payload = validation.get("payload")
            if isinstance(payload, dict):
                normalized_payload = dict(payload)
                normalized_payload["foreign_inspirations"] = (
                    self._normalize_foreign_inspirations_for_resume(
                        payload.get("foreign_inspirations") or {}
                    )
                )
                return normalized_payload

        saved_config = snapshot.get("multitask_config")
        if not isinstance(saved_config, dict):
            return None

        tasks: List[Dict[str, Any]] = []
        for task in saved_config.get("tasks") or []:
            if not isinstance(task, dict):
                continue
            related_tasks = []
            for related_task in task.get("related_tasks") or []:
                if not isinstance(related_task, dict):
                    continue
                related_tasks.append(
                    {
                        "source_task": related_task.get("source_task"),
                        "enabled": related_task.get("enabled", True),
                        "prompt_context": related_task.get("prompt_context"),
                    }
                )
            tasks.append(
                {
                    "name": task.get("name"),
                    "initial_program": task.get("initial_program"),
                    "evaluation_file": task.get("evaluation_file"),
                    "env": dict(sorted((task.get("env") or {}).items())),
                    "config_overrides": task.get("config_overrides") or {},
                    "related_tasks": related_tasks,
                }
            )

        base_config = saved_config.get("base_config") or {}
        if not isinstance(base_config, dict):
            base_config = {}

        return {
            "base_config_path": saved_config.get("base_config_path"),
            "execution_mode": saved_config.get("execution_mode", "sequential_round_robin"),
            "scheduler": saved_config.get("scheduler") or {},
            "foreign_inspirations": self._normalize_foreign_inspirations_for_resume(
                saved_config.get("foreign_inspirations") or {}
            ),
            "base_config": self._normalize_base_config_for_resume(base_config),
            "tasks": tasks,
        }

    def _describe_resume_validation_mismatch(
        self, saved_payload: Dict[str, Any], current_payload: Dict[str, Any]
    ) -> str:
        mismatches: List[str] = []

        saved_tasks = saved_payload.get("tasks") or []
        current_tasks = current_payload.get("tasks") or []
        saved_task_names = [task.get("name") for task in saved_tasks]
        current_task_names = [task.get("name") for task in current_tasks]
        if saved_task_names != current_task_names:
            mismatches.append(
                f"task order/names changed (saved={saved_task_names}, current={current_task_names})"
            )

        saved_tasks_by_name = {
            task.get("name"): task for task in saved_tasks if isinstance(task, dict)
        }
        current_tasks_by_name = {
            task.get("name"): task for task in current_tasks if isinstance(task, dict)
        }
        for task_name in sorted(set(saved_tasks_by_name) & set(current_tasks_by_name)):
            saved_task = saved_tasks_by_name[task_name]
            current_task = current_tasks_by_name[task_name]
            for field_name in (
                "initial_program",
                "evaluation_file",
                "env",
                "related_tasks",
                "config_overrides",
            ):
                if saved_task.get(field_name) != current_task.get(field_name):
                    mismatches.append(f"task '{task_name}' {field_name} changed")

        if saved_payload.get("base_config_path") != current_payload.get("base_config_path"):
            mismatches.append("base_config_path changed")
        if saved_payload.get("execution_mode") != current_payload.get("execution_mode"):
            mismatches.append("execution_mode changed")
        if saved_payload.get("scheduler") != current_payload.get("scheduler"):
            mismatches.append("scheduler changed")
        if saved_payload.get("foreign_inspirations") != current_payload.get(
            "foreign_inspirations"
        ):
            mismatches.append("foreign_inspirations changed")
        if saved_payload.get("base_config") != current_payload.get("base_config"):
            mismatches.append("base_config changed")

        if not mismatches:
            mismatches.append("resume validation hash changed")

        summary = "; ".join(mismatches[:5])
        if len(mismatches) > 5:
            summary += f"; and {len(mismatches) - 5} more differences"
        return (
            "Checkpoint config snapshot does not match the current multitask config: "
            f"{summary}"
        )

    def _validate_checkpoint_config_snapshot(
        self, checkpoint_root: Path, force_resume: bool = False
    ) -> None:
        snapshot_path = checkpoint_root / "multitask_config_snapshot.json"
        if not snapshot_path.exists():
            logger.warning(
                "Checkpoint config snapshot not found at %s; skipping resume config validation",
                snapshot_path,
            )
            return

        with open(snapshot_path, "r") as handle:
            snapshot = json.load(handle)

        saved_payload = self._build_resume_validation_payload_from_snapshot(snapshot)
        if saved_payload is None:
            logger.warning(
                "Checkpoint config snapshot at %s is missing validation data; skipping resume config validation",
                snapshot_path,
            )
            return

        saved_hash = self._hash_resume_validation_payload(saved_payload)
        if saved_hash == self._resume_validation_hash:
            return

        message = self._describe_resume_validation_mismatch(
            saved_payload=saved_payload,
            current_payload=self._resume_validation_payload,
        )
        if force_resume:
            logger.warning("%s. Proceeding because force_resume=True.", message)
            return

        raise ValueError(f"{message}. Pass force_resume=True or --force-resume to override.")

    def _initialize_tasks(self) -> None:
        for task_config in self.multitask_config.tasks:
            task_state = self._create_task_state(task_config)
            self.tasks.append(task_state)
            self.task_by_name[task_state.task_name] = task_state
        self._scheduler_counts = {task.task_name: 0 for task in self.tasks}
        self._task_best_fitness = {task.task_name: None for task in self.tasks}

    def _create_task_state(self, task_config: TaskConfig) -> TaskState:
        task_output_dir = os.path.join(
            self.output_dir, task_config.output_subdir or task_config.name
        )
        os.makedirs(task_output_dir, exist_ok=True)

        config = derive_task_config(
            base_config=self.base_config,
            overrides=task_config.config_overrides,
            config_dir=Path(self.multitask_config.config_dir or "."),
        )
        self._configure_task_output_paths(
            task_name=task_config.name,
            task_output_dir=task_output_dir,
            config=config,
        )
        config.database.db_path = None

        initial_program_path = task_config.initial_program
        with open(initial_program_path, "r") as handle:
            initial_program_code = handle.read()

        file_extension = Path(initial_program_path).suffix or ".py"
        if not config.language:
            config.language = extract_code_language(initial_program_code)
        if not getattr(config, "file_suffix", None) or config.file_suffix == ".py":
            config.file_suffix = file_extension

        previous_random_state = random.getstate()
        previous_numpy_state = np.random.get_state()

        try:
            self._ensure_task_log_handler(task_config.name, config.log_dir)
            with self._activate_task_logging(task_config.name):
                self._configure_task_random_seed(config)

                # Manual-mode LLM clients read the queue dir during construction.
                self._setup_manual_mode_queue(config=config, output_dir=task_output_dir)

                prompt_sampler = PromptSampler(config.prompt)
                evaluator_prompt_sampler = PromptSampler(config.prompt)
                evaluator_prompt_sampler.set_templates("evaluator_system_message")

                llm_ensemble = LLMEnsemble(config.llm.models)
                llm_evaluator_ensemble = LLMEnsemble(config.llm.evaluator_models)

                config.database.novelty_llm = llm_ensemble
                database = ProgramDatabase(config.database)

                with temporary_env(task_config.env):
                    evaluator = Evaluator(
                        config.evaluator,
                        task_config.evaluation_file,
                        llm_evaluator_ensemble,
                        evaluator_prompt_sampler,
                        database=database,
                        suffix=file_extension,
                    )

                evolution_tracer = self._create_evolution_tracer(config, task_output_dir)

                task_state = TaskState(
                    task_name=task_config.name,
                    initial_program_path=initial_program_path,
                    initial_program_code=initial_program_code,
                    evaluation_file=task_config.evaluation_file,
                    config=config,
                    database=database,
                    evaluator=evaluator,
                    prompt_sampler=prompt_sampler,
                    evaluator_prompt_sampler=evaluator_prompt_sampler,
                    llm_ensemble=llm_ensemble,
                    llm_evaluator_ensemble=llm_evaluator_ensemble,
                    output_dir=task_output_dir,
                    file_extension=file_extension,
                    env=task_config.env,
                    related_tasks=task_config.related_tasks,
                    random_state=random.getstate(),
                    numpy_random_state=np.random.get_state(),
                    evolution_tracer=evolution_tracer,
                )
        finally:
            random.setstate(previous_random_state)
            np.random.set_state(previous_numpy_state)

        logger.info(
            "Initialized multitask task '%s' with output dir %s",
            task_state.task_name,
            task_state.output_dir,
        )
        return task_state

    def _configure_task_output_paths(
        self, task_name: str, task_output_dir: str, config: Config
    ) -> None:
        """Resolve task runtime output paths under the task output dir and detect collisions."""
        config.log_dir = self._resolve_task_output_path(
            task_name=task_name,
            label="log_dir",
            task_output_dir=task_output_dir,
            configured_path=config.log_dir,
            default_path="logs",
        )
        config.database.artifacts_base_path = self._resolve_task_output_path(
            task_name=task_name,
            label="database.artifacts_base_path",
            task_output_dir=task_output_dir,
            configured_path=config.database.artifacts_base_path,
            default_path="artifacts",
        )

        if config.evolution_trace.enabled:
            config.evolution_trace.output_path = self._resolve_task_output_path(
                task_name=task_name,
                label="evolution_trace.output_path",
                task_output_dir=task_output_dir,
                configured_path=config.evolution_trace.output_path,
                default_path=f"evolution_trace.{config.evolution_trace.format}",
            )

    def _resolve_task_output_path(
        self,
        task_name: str,
        label: str,
        task_output_dir: str,
        configured_path: Optional[str],
        default_path: str,
    ) -> str:
        """Resolve a task-specific output path and ensure it is not shared across tasks."""
        if configured_path:
            path = Path(configured_path).expanduser()
            if not path.is_absolute():
                path = Path(task_output_dir) / path
        else:
            path = Path(task_output_dir) / default_path

        resolved_path = str(path.resolve())
        existing_claim = self._claimed_task_output_paths.get(resolved_path)
        if existing_claim and existing_claim[0] != task_name:
            claimed_task, claimed_label = existing_claim
            raise ValueError(
                f"Multitask output path collision: task '{task_name}' {label} resolves to "
                f"'{resolved_path}', already claimed by task '{claimed_task}' {claimed_label}"
            )

        self._claimed_task_output_paths[resolved_path] = (task_name, label)
        return resolved_path

    def _configure_task_random_seed(self, config: Config) -> None:
        """Apply OpenEvolve-style deterministic seeding to a task config."""
        if config.random_seed is None:
            return

        random.seed(config.random_seed)
        np.random.seed(config.random_seed)

        base_seed = str(config.random_seed).encode("utf-8")
        llm_seed = int(hashlib.md5(base_seed + b"llm").hexdigest()[:8], 16) % (2**31)

        config.llm.random_seed = llm_seed
        for model_cfg in config.llm.models:
            if getattr(model_cfg, "random_seed", None) is None:
                model_cfg.random_seed = llm_seed
        for model_cfg in config.llm.evaluator_models:
            if getattr(model_cfg, "random_seed", None) is None:
                model_cfg.random_seed = llm_seed

    def _setup_manual_mode_queue(self, config: Config, output_dir: str) -> None:
        if not bool(getattr(config.llm, "manual_mode", False)):
            return

        queue_dir = Path(output_dir).expanduser().resolve() / "manual_tasks_queue"
        if queue_dir.exists():
            shutil.rmtree(queue_dir)
        queue_dir.mkdir(parents=True, exist_ok=True)

        config.llm._manual_queue_dir = str(queue_dir)
        for model_cfg in config.llm.models:
            model_cfg._manual_queue_dir = str(queue_dir)
        for model_cfg in config.llm.evaluator_models:
            model_cfg._manual_queue_dir = str(queue_dir)

    def _create_evolution_tracer(
        self, config: Config, output_dir: str
    ) -> Optional[EvolutionTracer]:
        if not config.evolution_trace.enabled:
            return None

        trace_output_path = config.evolution_trace.output_path
        if not trace_output_path:
            trace_output_path = os.path.join(
                output_dir, f"evolution_trace.{config.evolution_trace.format}"
            )

        return EvolutionTracer(
            output_path=trace_output_path,
            format=config.evolution_trace.format,
            include_code=config.evolution_trace.include_code,
            include_prompts=config.evolution_trace.include_prompts,
            enabled=True,
            buffer_size=config.evolution_trace.buffer_size,
            compress=config.evolution_trace.compress,
        )

    @contextmanager
    def _activate_task_context(self, task_state: TaskState):
        previous_random_state = random.getstate()
        previous_numpy_state = np.random.get_state()
        log_token = _CURRENT_TASK_LOG_CONTEXT.set(task_state.task_name)
        try:
            if task_state.random_state is not None:
                random.setstate(task_state.random_state)
            if task_state.numpy_random_state is not None:
                np.random.set_state(task_state.numpy_random_state)
            with temporary_env(task_state.env):
                yield
        finally:
            task_state.random_state = random.getstate()
            task_state.numpy_random_state = np.random.get_state()
            random.setstate(previous_random_state)
            np.random.set_state(previous_numpy_state)
            _CURRENT_TASK_LOG_CONTEXT.reset(log_token)

    def _get_task_best_fitness(self, task_state: TaskState) -> Optional[float]:
        best_program = task_state.database.get_best_program()
        if best_program is None or not best_program.metrics:
            return None
        return get_fitness_score(
            best_program.metrics,
            task_state.database.config.feature_dimensions,
        )

    def _aggregate_best_fitness(self) -> Dict[str, Optional[float]]:
        task_best_values = [
            fitness for fitness in self._task_best_fitness.values() if fitness is not None
        ]
        if not task_best_values:
            return {
                "multitask/best_task_fitness_max": None,
                "multitask/best_task_fitness_mean": None,
            }

        return {
            "multitask/best_task_fitness_max": max(task_best_values),
            "multitask/best_task_fitness_mean": sum(task_best_values) / len(task_best_values),
        }

    def _iterations_since_improvement(
        self, task_state: TaskState | FrozenTaskTransferState
    ) -> int:
        return task_state.no_improve_steps

    def _get_enabled_related_tasks(self, task_state: TaskState) -> List[RelatedTaskConfig]:
        return [related_task for related_task in task_state.related_tasks if related_task.enabled]

    def _get_transfer_bandit_arm_names(self, task_state: TaskState) -> List[str]:
        arm_names = ["NONE"]
        seen = {"NONE"}
        for related_task in self._get_enabled_related_tasks(task_state):
            if related_task.source_task in seen:
                continue
            seen.add(related_task.source_task)
            arm_names.append(related_task.source_task)
        return arm_names

    def _ensure_transfer_bandit_state(self, task_state: TaskState, *, reset: bool = False) -> None:
        arm_names = self._get_transfer_bandit_arm_names(task_state)
        alpha = task_state.transfer_bandit_alpha if not reset else {}
        beta = task_state.transfer_bandit_beta if not reset else {}
        pulls = task_state.transfer_bandit_pulls if not reset else {}
        task_state.transfer_bandit_alpha = {
            arm_name: float(alpha.get(arm_name, 1.0)) for arm_name in arm_names
        }
        task_state.transfer_bandit_beta = {
            arm_name: float(beta.get(arm_name, 1.0)) for arm_name in arm_names
        }
        task_state.transfer_bandit_pulls = {
            arm_name: int(pulls.get(arm_name, 0)) for arm_name in arm_names
        }

    def _snapshot_task_transfer_states(self) -> Dict[str, FrozenTaskTransferState]:
        snapshots: Dict[str, FrozenTaskTransferState] = {}
        for task_state in self.tasks:
            self._ensure_transfer_bandit_state(task_state)
            snapshots[task_state.task_name] = FrozenTaskTransferState(
                local_iteration=task_state.local_iteration,
                no_improve_steps=task_state.no_improve_steps,
                last_improvement_iteration=task_state.last_improvement_iteration,
                last_transfer_iteration=task_state.last_transfer_iteration,
                transfer_bandit_alpha=dict(task_state.transfer_bandit_alpha),
                transfer_bandit_beta=dict(task_state.transfer_bandit_beta),
                transfer_bandit_pulls=dict(task_state.transfer_bandit_pulls),
            )
        return snapshots

    def _is_task_stagnating(self, task_state: TaskState | FrozenTaskTransferState) -> bool:
        patience = self.multitask_config.foreign_inspirations.stagnation_patience
        if patience <= 0:
            return False
        return task_state.no_improve_steps >= patience

    def _is_stagnation_trigger_ready(
        self, task_state: TaskState | FrozenTaskTransferState
    ) -> bool:
        cfg = self.multitask_config.foreign_inspirations
        if task_state.local_iteration < cfg.warmup_task_iterations:
            return False
        if not self._is_task_stagnating(task_state):
            return False
        if task_state.last_transfer_iteration is None:
            return True
        return (
            task_state.local_iteration - task_state.last_transfer_iteration
        ) >= cfg.transfer_cooldown

    def _get_foreign_transfer_trigger(
        self,
        task_state: TaskState,
        frozen_state: Optional[FrozenTaskTransferState] = None,
    ) -> Optional[str]:
        cfg = self.multitask_config.foreign_inspirations
        transfer_state = frozen_state or task_state
        if not cfg.enabled or cfg.max_related_tasks == 0:
            return None
        if transfer_state.local_iteration < cfg.warmup_task_iterations:
            return None
        if cfg.trigger_mode == "periodic":
            if (
                (transfer_state.local_iteration - cfg.warmup_task_iterations)
                % cfg.every_n_task_iterations
            ) == 0:
                return "periodic"
            return None
        if self._is_stagnation_trigger_ready(transfer_state):
            return cfg.trigger_mode
        return None

    def _sample_task_betavariate(self, task_state: TaskState, alpha: float, beta: float) -> float:
        rng = random.Random()
        if task_state.random_state is not None:
            rng.setstate(task_state.random_state)
        sample = rng.betavariate(alpha, beta)
        task_state.random_state = rng.getstate()
        return sample

    def _select_online_bandit_arm(
        self,
        task_state: TaskState,
        transfer_state: TaskState | FrozenTaskTransferState,
    ) -> str:
        cfg = self.multitask_config.foreign_inspirations
        arm_names = self._get_transfer_bandit_arm_names(task_state)
        pulls = transfer_state.transfer_bandit_pulls
        under_sampled_arms = [
            arm_name for arm_name in arm_names if pulls.get(arm_name, 0) < cfg.min_pulls_per_arm
        ]
        if under_sampled_arms:
            return under_sampled_arms[0]

        samples = {
            arm_name: self._sample_task_betavariate(
                task_state,
                transfer_state.transfer_bandit_alpha.get(arm_name, 1.0),
                transfer_state.transfer_bandit_beta.get(arm_name, 1.0),
            )
            for arm_name in arm_names
        }
        return max(arm_names, key=lambda arm_name: samples[arm_name])

    def _collect_foreign_inspirations(
        self,
        task_state: TaskState,
        *,
        allowed_source_tasks: Optional[set[str]] = None,
        max_related_tasks: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        cfg = self.multitask_config.foreign_inspirations
        selected_sources: List[Dict[str, Any]] = []
        selection_limit = cfg.max_related_tasks if max_related_tasks is None else max_related_tasks

        for related_task in task_state.related_tasks:
            if len(selected_sources) >= selection_limit:
                break
            if not related_task.enabled:
                continue
            if (
                allowed_source_tasks is not None
                and related_task.source_task not in allowed_source_tasks
            ):
                continue

            source_state = self.task_by_name[related_task.source_task]
            top_programs = source_state.database.get_top_programs(
                n=cfg.top_programs_per_related_task
            )
            if not top_programs:
                continue

            selected_sources.append(
                {
                    "source_task": source_state.task_name,
                    "prompt_context": (
                        related_task.prompt_context
                        if cfg.include_optional_relation_text
                        else None
                    ),
                    "include_scores": cfg.include_scores,
                    "include_code": cfg.include_code,
                    "programs": [program.to_dict() for program in top_programs],
                }
            )

        return selected_sources

    def _get_foreign_transfer_decision(
        self,
        task_state: TaskState,
        frozen_state: Optional[FrozenTaskTransferState] = None,
    ) -> ForeignTransferDecision:
        cfg = self.multitask_config.foreign_inspirations
        transfer_state = frozen_state or task_state
        self._ensure_transfer_bandit_state(task_state)
        trigger_reason = self._get_foreign_transfer_trigger(task_state, frozen_state=transfer_state)
        if trigger_reason is None:
            return ForeignTransferDecision(trigger_mode=cfg.trigger_mode)

        if cfg.trigger_mode in {"periodic", "stagnation"}:
            foreign_inspirations = self._collect_foreign_inspirations(task_state)
            if not foreign_inspirations:
                return ForeignTransferDecision(trigger_mode=cfg.trigger_mode)
            return ForeignTransferDecision(
                trigger_mode=cfg.trigger_mode,
                trigger_reason=trigger_reason,
                foreign_inspirations=foreign_inspirations,
            )

        chosen_arm = self._select_online_bandit_arm(task_state, transfer_state)
        if chosen_arm == "NONE":
            return ForeignTransferDecision(
                trigger_mode=cfg.trigger_mode,
                trigger_reason=trigger_reason,
                chosen_transfer_arm="NONE",
            )

        foreign_inspirations = self._collect_foreign_inspirations(
            task_state,
            allowed_source_tasks={chosen_arm},
            max_related_tasks=1,
        )
        if not foreign_inspirations:
            logger.warning(
                "Task '%s' selected online-bandit source arm '%s' but no foreign inspiration "
                "payload was available; skipping transfer and posterior update",
                task_state.task_name,
                chosen_arm,
            )
            return ForeignTransferDecision(
                trigger_mode=cfg.trigger_mode,
                trigger_reason=trigger_reason,
            )
        return ForeignTransferDecision(
            trigger_mode=cfg.trigger_mode,
            trigger_reason=trigger_reason,
            chosen_transfer_arm=chosen_arm,
            foreign_inspirations=foreign_inspirations,
        )

    def _compute_transfer_bandit_scores(self, task_state: TaskState) -> Dict[str, float]:
        self._ensure_transfer_bandit_state(task_state)
        return {
            arm_name: task_state.transfer_bandit_alpha[arm_name]
            / (
                task_state.transfer_bandit_alpha[arm_name]
                + task_state.transfer_bandit_beta[arm_name]
            )
            for arm_name in self._get_transfer_bandit_arm_names(task_state)
        }

    def _update_transfer_bandit_posterior(
        self,
        task_state: TaskState,
        *,
        chosen_arm: str,
        reward: int,
    ) -> None:
        self._ensure_transfer_bandit_state(task_state)
        cfg = self.multitask_config.foreign_inspirations
        task_state.transfer_bandit_alpha[chosen_arm] = (
            cfg.bandit_decay * task_state.transfer_bandit_alpha[chosen_arm] + reward
        )
        task_state.transfer_bandit_beta[chosen_arm] = (
            cfg.bandit_decay * task_state.transfer_bandit_beta[chosen_arm] + (1 - reward)
        )
        task_state.transfer_bandit_pulls[chosen_arm] += 1

    def _add_transfer_bandit_metrics(
        self, metrics: Dict[str, Any], task_prefix: str, task_state: TaskState
    ) -> None:
        if self.multitask_config.foreign_inspirations.trigger_mode != "online_bandit":
            return
        for arm_name, score in self._compute_transfer_bandit_scores(task_state).items():
            metric_suffix = "none" if arm_name == "NONE" else arm_name
            metrics[f"{task_prefix}/bandit_score_{metric_suffix}"] = score

    def _compute_improvement_outcome(
        self,
        *,
        previous_best: Optional[float],
        current_best: Optional[float],
    ) -> Tuple[float, bool]:
        if current_best is None:
            return 0.0, False
        if previous_best is None:
            return current_best, True

        delta_best = current_best - previous_best
        improvement_detected = (
            current_best
            > previous_best
            + self.multitask_config.foreign_inspirations.min_best_fitness_improvement
        )
        return delta_best, improvement_detected

    def _compute_sparse_online_bandit_reward(self, *, improvement_detected: bool) -> int:
        return 1 if improvement_detected else 0

    def _get_child_fitness_for_reward(
        self,
        *,
        task_state: TaskState,
        child_program: Optional[Program],
    ) -> Optional[float]:
        if child_program is None or not child_program.metrics:
            return None

        child_fitness = get_fitness_score(
            child_program.metrics,
            task_state.database.config.feature_dimensions,
        )
        if not np.isfinite(child_fitness):
            return None
        return float(child_fitness)

    def _trim_recent_child_fitness_history(self, task_state: TaskState) -> None:
        max_history = max(32, self.multitask_config.foreign_inspirations.reward_window)
        if len(task_state.recent_child_fitness_history) > max_history:
            task_state.recent_child_fitness_history = task_state.recent_child_fitness_history[
                -max_history:
            ]

    def _append_recent_child_fitness(
        self,
        *,
        task_state: TaskState,
        child_fitness: float,
    ) -> None:
        task_state.recent_child_fitness_history.append(float(child_fitness))
        self._trim_recent_child_fitness_history(task_state)

    def _get_recent_child_fitness_baseline(self, task_state: TaskState) -> Optional[float]:
        history = task_state.recent_child_fitness_history
        if not history:
            return None

        reward_window = self.multitask_config.foreign_inspirations.reward_window
        recent_window = history[-reward_window:]
        if not recent_window:
            return None
        return float(np.median(recent_window))

    def _compute_online_bandit_reward(
        self,
        *,
        task_state: TaskState,
        improvement_detected: bool,
        child_fitness: Optional[float],
    ) -> Tuple[int, Optional[float]]:
        sparse_reward = self._compute_sparse_online_bandit_reward(
            improvement_detected=improvement_detected
        )
        cfg = self.multitask_config.foreign_inspirations
        if cfg.reward_mode != "rich":
            return sparse_reward, None

        reward_baseline_fitness = self._get_recent_child_fitness_baseline(task_state)
        if reward_baseline_fitness is None:
            return sparse_reward, None
        if child_fitness is None:
            return 0, reward_baseline_fitness

        rich_reward = 1 if child_fitness > reward_baseline_fitness + cfg.reward_margin else 0
        return rich_reward, reward_baseline_fitness

    def _update_task_progress_state(
        self,
        *,
        task_state: TaskState,
        previous_best: Optional[float],
        current_best: Optional[float],
        local_iteration: int,
        foreign_transfer_used: bool,
        chosen_transfer_arm: Optional[str],
        child_program: Optional[Program] = None,
    ) -> TaskProgressUpdate:
        delta_best, improvement_detected = self._compute_improvement_outcome(
            previous_best=previous_best,
            current_best=current_best,
        )

        if improvement_detected:
            task_state.no_improve_steps = 0
            task_state.last_improvement_iteration = local_iteration
        else:
            task_state.no_improve_steps += 1

        reward_for_chosen_arm: Optional[int] = None
        reward_mode: Optional[str] = None
        child_fitness_for_reward: Optional[float] = None
        reward_baseline_fitness: Optional[float] = None
        if self.multitask_config.foreign_inspirations.trigger_mode == "online_bandit":
            reward_mode = self.multitask_config.foreign_inspirations.reward_mode
            child_fitness_for_reward = self._get_child_fitness_for_reward(
                task_state=task_state,
                child_program=child_program,
            )
            if chosen_transfer_arm is not None:
                reward_for_chosen_arm, reward_baseline_fitness = (
                    self._compute_online_bandit_reward(
                        task_state=task_state,
                        improvement_detected=improvement_detected,
                        child_fitness=child_fitness_for_reward,
                    )
                )
                self._update_transfer_bandit_posterior(
                    task_state,
                    chosen_arm=chosen_transfer_arm,
                    reward=reward_for_chosen_arm,
                )
            if child_fitness_for_reward is not None:
                self._append_recent_child_fitness(
                    task_state=task_state,
                    child_fitness=child_fitness_for_reward,
                )

        if foreign_transfer_used:
            task_state.last_transfer_iteration = local_iteration

        return TaskProgressUpdate(
            delta_best=delta_best,
            improvement_detected=improvement_detected,
            reward_for_chosen_arm=reward_for_chosen_arm,
            reward_mode=reward_mode,
            child_fitness_for_reward=child_fitness_for_reward,
            reward_baseline_fitness=reward_baseline_fitness,
        )

    def _initialize_task_progress_state(self, *, from_checkpoint: bool) -> None:
        for task_state in self.tasks:
            self._ensure_transfer_bandit_state(task_state, reset=not from_checkpoint)
            if not from_checkpoint:
                best_program = task_state.database.get_best_program()
                task_state.no_improve_steps = 0
                task_state.last_improvement_iteration = (
                    best_program.iteration_found if best_program else None
                )
                task_state.last_transfer_iteration = None
                task_state.recent_child_fitness_history = []
            self._trim_recent_child_fitness_history(task_state)
            self._task_best_fitness[task_state.task_name] = self._get_task_best_fitness(task_state)

    def _build_wandb_run_metadata(
        self,
        *,
        checkpoint_path: Optional[str],
        max_global_iterations: int,
    ) -> Dict[str, Any]:
        metadata = {
            "mode": "multitask",
            "output_dir": str(Path(self.output_dir).resolve()),
            "requested_global_iterations": max_global_iterations,
            "scheduler_strategy": self.multitask_config.scheduler.strategy,
            "task_names": [task.task_name for task in self.tasks],
            "evaluation_files": {
                task.task_name: str(Path(task.evaluation_file).resolve()) for task in self.tasks
            },
            "initial_program_paths": {
                task.task_name: str(Path(task.initial_program_path).resolve()) for task in self.tasks
            },
        }
        if checkpoint_path:
            metadata["checkpoint_path"] = str(Path(checkpoint_path).resolve())
        return metadata

    def _log_multitask_initial_state(self) -> None:
        metrics: Dict[str, Any] = {
            "multitask/global_iteration": self.completed_global_iterations,
            "multitask/num_active_tasks": len(self.tasks),
        }
        metrics.update(self._aggregate_best_fitness())

        for task_state in self.tasks:
            best_program = task_state.database.get_best_program()
            best_fitness = self._get_task_best_fitness(task_state)
            self._task_best_fitness[task_state.task_name] = best_fitness
            task_prefix = f"task/{task_state.task_name}"
            metrics[f"{task_prefix}/task_local_iteration"] = task_state.local_iteration
            metrics[f"{task_prefix}/best_fitness"] = best_fitness
            metrics[f"{task_prefix}/evaluation_success"] = 1 if best_program else 0
            metrics[f"{task_prefix}/no_improve_steps"] = task_state.no_improve_steps
            metrics[f"{task_prefix}/last_improvement_iteration"] = (
                task_state.last_improvement_iteration
                if task_state.last_improvement_iteration is not None
                else -1
            )
            metrics[f"{task_prefix}/last_transfer_iteration"] = (
                task_state.last_transfer_iteration
                if task_state.last_transfer_iteration is not None
                else -1
            )
            metrics[f"{task_prefix}/iterations_since_improvement"] = (
                self._iterations_since_improvement(task_state)
            )
            metrics[f"{task_prefix}/trigger_mode"] = self.multitask_config.foreign_inspirations.trigger_mode
            metrics[f"{task_prefix}/stagnating"] = 1 if self._is_task_stagnating(task_state) else 0
            metrics[f"{task_prefix}/stagnation_ready"] = (
                1 if self._is_stagnation_trigger_ready(task_state) else 0
            )
            self._add_transfer_bandit_metrics(metrics, task_prefix, task_state)
            if best_program is not None:
                metrics.update(
                    {
                        key: value
                        for key, value in flatten_scalars(best_program.metrics, prefix=task_prefix).items()
                        if key not in metrics
                    }
                )

        self.wandb_logger.log_metrics(metrics, step=self.completed_global_iterations)

    def _log_task_best_program_artifact(self, task_state: TaskState, program: Program) -> None:
        paths = self._save_task_best_program(task_state)
        if not paths:
            return

        self.wandb_logger.log_best_program_artifact(
            paths["code_path"],
            metadata={
                "task_name": task_state.task_name,
                "iteration": program.iteration_found,
                "fitness": self._get_task_best_fitness(task_state),
                "program_id": program.id,
                "metrics": flatten_scalars(program.metrics),
            },
            task_name=task_state.task_name,
        )

    def _log_multitask_step(
        self,
        *,
        global_iteration: int,
        task_state: TaskState,
        result: TaskIterationResult,
    ) -> None:
        task_name = task_state.task_name
        task_prefix = f"task/{task_name}"
        previous_best = self._task_best_fitness.get(task_name)
        current_task_best = self._get_task_best_fitness(task_state)
        if current_task_best is not None:
            self._task_best_fitness[task_name] = current_task_best

        foreign_transfer_used = bool(result.foreign_inspiration_sources)
        progress_update = self._update_task_progress_state(
            task_state=task_state,
            previous_best=previous_best,
            current_best=current_task_best,
            local_iteration=result.local_iteration,
            foreign_transfer_used=foreign_transfer_used,
            chosen_transfer_arm=result.chosen_transfer_arm,
            child_program=result.child_program,
        )
        iterations_since_improvement = self._iterations_since_improvement(task_state)
        stagnating = self._is_task_stagnating(task_state)
        stagnation_ready = self._is_stagnation_trigger_ready(task_state)
        transfer_trigger_reason = result.foreign_transfer_trigger_reason or ""

        current_fitness = None
        if result.child_program is not None:
            current_fitness = get_fitness_score(
                result.child_program.metrics,
                task_state.database.config.feature_dimensions,
            )

        metrics: Dict[str, Any] = {
            "multitask/global_iteration": global_iteration,
            "multitask/selected_task": task_name,
            "multitask/num_active_tasks": len(self.tasks),
            "multitask/current_fitness": current_fitness,
            "multitask/evaluation_success": 1 if result.success else 0,
            "multitask/trigger_mode": self.multitask_config.foreign_inspirations.trigger_mode,
            "multitask/foreign_inspirations_used": 1 if foreign_transfer_used else 0,
            "multitask/foreign_transfer_used": 1 if foreign_transfer_used else 0,
            "multitask/foreign_transfer_triggered": 1 if foreign_transfer_used else 0,
            "multitask/foreign_transfer_trigger_reason": transfer_trigger_reason,
            "multitask/chosen_transfer_arm": result.chosen_transfer_arm or "",
            "multitask/foreign_transfer_reward": (
                progress_update.reward_for_chosen_arm
                if progress_update.reward_for_chosen_arm is not None
                else -1
            ),
            "multitask/num_foreign_inspirations": len(result.foreign_inspiration_sources),
            "multitask/foreign_inspiration_sources": ",".join(result.foreign_inspiration_sources),
            "multitask/scheduler_count": self._scheduler_counts.get(task_name, 0),
            "task_name": task_name,
            "task_local_iteration": result.local_iteration,
            f"{task_prefix}/task_local_iteration": result.local_iteration,
            f"{task_prefix}/current_fitness": current_fitness,
            f"{task_prefix}/best_fitness": current_task_best,
            f"{task_prefix}/delta_best_fitness": progress_update.delta_best,
            f"{task_prefix}/evaluation_success": 1 if result.success else 0,
            f"{task_prefix}/iteration_time_sec": result.iteration_time_sec,
            f"{task_prefix}/generation_time_sec": result.generation_time_sec,
            f"{task_prefix}/evaluation_time_sec": result.evaluation_time_sec,
            f"{task_prefix}/failure_reason": result.failure_reason,
            f"{task_prefix}/no_improve_steps": task_state.no_improve_steps,
            f"{task_prefix}/trigger_mode": self.multitask_config.foreign_inspirations.trigger_mode,
            f"{task_prefix}/last_improvement_iteration": (
                task_state.last_improvement_iteration
                if task_state.last_improvement_iteration is not None
                else -1
            ),
            f"{task_prefix}/last_transfer_iteration": (
                task_state.last_transfer_iteration
                if task_state.last_transfer_iteration is not None
                else -1
            ),
            f"{task_prefix}/iterations_since_improvement": iterations_since_improvement,
            f"{task_prefix}/stagnating": 1 if stagnating else 0,
            f"{task_prefix}/stagnation_ready": 1 if stagnation_ready else 0,
            f"{task_prefix}/chosen_transfer_arm": result.chosen_transfer_arm or "",
            f"{task_prefix}/foreign_transfer_used": 1 if foreign_transfer_used else 0,
            f"{task_prefix}/foreign_transfer_reward": (
                progress_update.reward_for_chosen_arm
                if progress_update.reward_for_chosen_arm is not None
                else -1
            ),
            f"{task_prefix}/foreign_transfer_triggered": 1 if foreign_transfer_used else 0,
            f"{task_prefix}/foreign_transfer_trigger_reason": transfer_trigger_reason,
        }
        if progress_update.reward_mode is not None:
            metrics.update(
                {
                    "multitask/bandit_reward_mode": progress_update.reward_mode,
                    "multitask/bandit_reward_child_fitness": (
                        progress_update.child_fitness_for_reward
                    ),
                    "multitask/bandit_reward_baseline": (
                        progress_update.reward_baseline_fitness
                    ),
                    f"{task_prefix}/bandit_reward_mode": progress_update.reward_mode,
                    f"{task_prefix}/bandit_reward_child_fitness": (
                        progress_update.child_fitness_for_reward
                    ),
                    f"{task_prefix}/bandit_reward_baseline": (
                        progress_update.reward_baseline_fitness
                    ),
                }
            )
        self._add_transfer_bandit_metrics(metrics, task_prefix, task_state)
        metrics.update(self._aggregate_best_fitness())

        if result.child_program is not None:
            metrics.update(
                {
                    key: value
                    for key, value in flatten_scalars(
                        result.child_program.metrics,
                        prefix=task_prefix,
                    ).items()
                    if key not in metrics
                }
            )
        logger.info(
            "Task '%s' committed transfer state at iteration %d: mode=%s no_improve_steps=%d "
            "chosen_arm=%s foreign_used=%s reward_mode=%s reward=%s child_fitness=%s "
            "reward_baseline=%s last_transfer_iteration=%s bandit_scores=%s",
            task_name,
            result.local_iteration,
            self.multitask_config.foreign_inspirations.trigger_mode,
            task_state.no_improve_steps,
            result.chosen_transfer_arm,
            foreign_transfer_used,
            progress_update.reward_mode,
            progress_update.reward_for_chosen_arm,
            progress_update.child_fitness_for_reward,
            progress_update.reward_baseline_fitness,
            task_state.last_transfer_iteration,
            self._compute_transfer_bandit_scores(task_state)
            if self.multitask_config.foreign_inspirations.trigger_mode == "online_bandit"
            else {},
        )
        self.wandb_logger.log_metrics(metrics, step=global_iteration)

        if (
            result.child_program is not None
            and task_state.database.best_program_id == result.child_program.id
        ):
            self._log_task_best_program_artifact(task_state, result.child_program)
            self.wandb_logger.update_summary(
                {
                    f"{task_prefix}/best_fitness": current_task_best,
                    f"{task_prefix}/best_iteration": result.child_program.iteration_found,
                }
            )

    async def _ensure_initial_program(self, task_state: TaskState) -> None:
        if task_state.database.programs:
            return

        initial_program_id = str(uuid.uuid4())
        with self._activate_task_context(task_state):
            initial_metrics = await task_state.evaluator.evaluate_program(
                task_state.initial_program_code, initial_program_id
            )

        initial_program = Program(
            id=initial_program_id,
            code=task_state.initial_program_code,
            changes_description=task_state.config.prompt.initial_changes_description,
            language=task_state.config.language,
            metrics=initial_metrics,
            iteration_found=0,
        )
        task_state.database.add(initial_program, iteration=0)

        artifacts = task_state.evaluator.get_pending_artifacts(initial_program_id)
        if artifacts:
            task_state.database.store_artifacts(initial_program_id, artifacts)

        logger.info(
            "Initialized task '%s' with initial metrics: %s",
            task_state.task_name,
            format_metrics_safe(initial_metrics),
        )

    def _should_include_foreign_inspirations(self, task_state: TaskState) -> bool:
        if self.multitask_config.foreign_inspirations.trigger_mode != "online_bandit":
            return self._get_foreign_transfer_trigger(task_state) is not None
        return bool(self._get_foreign_transfer_decision(task_state).foreign_inspirations)

    def _select_foreign_inspirations_with_reason(
        self,
        task_state: TaskState,
        frozen_state: Optional[FrozenTaskTransferState] = None,
    ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
        decision = self._get_foreign_transfer_decision(task_state, frozen_state=frozen_state)
        return decision.foreign_inspirations, decision.trigger_reason

    def _select_foreign_inspirations(self, task_state: TaskState) -> List[Dict[str, Any]]:
        foreign_inspirations, _trigger_reason = self._select_foreign_inspirations_with_reason(
            task_state
        )
        return foreign_inspirations

    @staticmethod
    def _resolve_num_local_inspirations(prompt_config: PromptConfig) -> int:
        if prompt_config.num_local_inspirations is not None:
            return prompt_config.num_local_inspirations
        return prompt_config.num_top_programs

    def _get_effective_prompt_config(
        self,
        task_state: TaskState,
        *,
        foreign_inspirations: List[Dict[str, Any]],
    ) -> PromptConfig:
        effective_prompt_config = replace(task_state.config.prompt)
        prompt_overrides = self.multitask_config.foreign_inspirations.prompt_overrides
        if not foreign_inspirations or prompt_overrides is None:
            return effective_prompt_config

        for field_name in (
            "num_top_programs",
            "num_diverse_programs",
            "num_local_inspirations",
        ):
            override_value = getattr(prompt_overrides, field_name)
            if override_value is not None:
                setattr(effective_prompt_config, field_name, override_value)

        return effective_prompt_config

    def _get_iteration_prompt_sampler(
        self,
        task_state: TaskState,
        prompt_config: PromptConfig,
    ) -> PromptSampler:
        if (
            prompt_config.num_top_programs == task_state.config.prompt.num_top_programs
            and prompt_config.num_diverse_programs
            == task_state.config.prompt.num_diverse_programs
        ):
            return task_state.prompt_sampler
        return PromptSampler(prompt_config)

    def _prepare_task_iteration_request(
        self,
        task_state: TaskState,
        frozen_transfer_states: Optional[Dict[str, FrozenTaskTransferState]] = None,
    ) -> TaskIterationRequest:
        """Prepare one serializable task attempt from the current frozen task state."""
        iteration_number = task_state.local_iteration + 1
        target_island = (iteration_number - 1) % task_state.config.database.num_islands
        decision = self._get_foreign_transfer_decision(
            task_state,
            frozen_state=(
                frozen_transfer_states.get(task_state.task_name)
                if frozen_transfer_states is not None
                else None
            ),
        )
        effective_prompt_config = self._get_effective_prompt_config(
            task_state,
            foreign_inspirations=decision.foreign_inspirations,
        )
        num_local_inspirations = self._resolve_num_local_inspirations(effective_prompt_config)

        with self._activate_task_context(task_state):
            parent, inspirations = task_state.database.sample_from_island(
                island_id=target_island,
                num_inspirations=num_local_inspirations,
            )

        island_programs = [
            task_state.database.programs[program_id]
            for program_id in task_state.database.islands[target_island]
            if program_id in task_state.database.programs
        ]
        island_programs.sort(
            key=lambda program: get_fitness_score(
                program.metrics, task_state.database.config.feature_dimensions
            ),
            reverse=True,
        )

        programs_for_prompt = island_programs[
            : effective_prompt_config.num_top_programs + effective_prompt_config.num_diverse_programs
        ]
        best_programs_only = island_programs[: effective_prompt_config.num_top_programs]

        return TaskIterationRequest(
            task_name=task_state.task_name,
            local_iteration=iteration_number,
            target_island=target_island,
            parent_program=parent.to_dict(),
            inspirations=[program.to_dict() for program in inspirations],
            previous_programs=[program.to_dict() for program in best_programs_only],
            top_programs=[program.to_dict() for program in programs_for_prompt],
            parent_artifacts=task_state.database.get_artifacts(parent.id),
            foreign_inspirations=decision.foreign_inspirations,
            feature_dimensions=list(task_state.database.config.feature_dimensions),
            rng_state=self._build_worker_rng_state(task_state),
            foreign_transfer_trigger_reason=decision.trigger_reason,
            chosen_transfer_arm=decision.chosen_transfer_arm,
            effective_prompt_config=(
                asdict(effective_prompt_config)
                if effective_prompt_config != task_state.config.prompt
                else None
            ),
        )

    def _build_worker_rng_state(self, task_state: TaskState) -> WorkerRngState:
        """Capture the task RNG state owned by the main process for worker execution."""
        return WorkerRngState(
            python_random_state=task_state.random_state,
            numpy_random_state=task_state.numpy_random_state,
            llm_random_state=task_state.llm_ensemble.random_state.getstate(),
            evaluator_llm_random_state=task_state.llm_evaluator_ensemble.random_state.getstate(),
        )

    def _apply_worker_rng_state(self, task_state: TaskState, rng_state: WorkerRngState) -> None:
        """Restore authoritative worker-updated RNG state back into main-process task state."""
        task_state.random_state = rng_state.python_random_state
        task_state.numpy_random_state = rng_state.numpy_random_state
        task_state.llm_ensemble.random_state.setstate(rng_state.llm_random_state)
        task_state.llm_evaluator_ensemble.random_state.setstate(
            rng_state.evaluator_llm_random_state
        )

    async def _run_task_iteration(self, task_state: TaskState) -> TaskIterationResult:
        iteration_number = task_state.local_iteration + 1
        target_island = (iteration_number - 1) % task_state.config.database.num_islands
        decision = self._get_foreign_transfer_decision(task_state)
        foreign_inspirations = decision.foreign_inspirations
        foreign_transfer_trigger_reason = decision.trigger_reason
        effective_prompt_config = self._get_effective_prompt_config(
            task_state,
            foreign_inspirations=foreign_inspirations,
        )
        num_local_inspirations = self._resolve_num_local_inspirations(effective_prompt_config)
        prompt_sampler = self._get_iteration_prompt_sampler(task_state, effective_prompt_config)
        foreign_inspiration_sources = [
            source["source_task"] for source in foreign_inspirations if source.get("source_task")
        ]

        def fail(reason: str) -> TaskIterationResult:
            logger.warning(
                "Task '%s' iteration %d discarded: %s",
                task_state.task_name,
                iteration_number,
                reason,
            )
            return TaskIterationResult(
                task_name=task_state.task_name,
                local_iteration=iteration_number,
                success=False,
                failure_reason=reason,
                foreign_inspiration_sources=foreign_inspiration_sources,
                foreign_transfer_trigger_reason=foreign_transfer_trigger_reason,
                chosen_transfer_arm=decision.chosen_transfer_arm,
            )

        with self._activate_task_context(task_state):
            parent, inspirations = task_state.database.sample_from_island(
                island_id=target_island,
                num_inspirations=num_local_inspirations,
            )

            island_programs = [
                task_state.database.programs[program_id]
                for program_id in task_state.database.islands[target_island]
                if program_id in task_state.database.programs
            ]
            island_programs.sort(
                key=lambda program: get_fitness_score(
                    program.metrics, task_state.database.config.feature_dimensions
                ),
                reverse=True,
            )

            programs_for_prompt = island_programs[
                : effective_prompt_config.num_top_programs
                + effective_prompt_config.num_diverse_programs
            ]
            best_programs_only = island_programs[: effective_prompt_config.num_top_programs]

            if effective_prompt_config.programs_as_changes_description:
                parent_changes_desc = (
                    parent.changes_description
                    or effective_prompt_config.initial_changes_description
                )
                child_changes_desc = parent_changes_desc
            else:
                parent_changes_desc = None
                child_changes_desc = ""

            prompt = prompt_sampler.build_prompt(
                current_program=parent.code,
                parent_program=parent.code,
                program_metrics=parent.metrics,
                previous_programs=[program.to_dict() for program in best_programs_only],
                top_programs=[program.to_dict() for program in programs_for_prompt],
                inspirations=[program.to_dict() for program in inspirations],
                foreign_inspirations=foreign_inspirations,
                language=task_state.config.language,
                evolution_round=iteration_number,
                diff_based_evolution=task_state.config.diff_based_evolution,
                program_artifacts=task_state.database.get_artifacts(parent.id),
                feature_dimensions=task_state.database.config.feature_dimensions,
                current_changes_description=parent_changes_desc,
            )

            iteration_start = time.time()
            generation_start = time.time()
            llm_response = await task_state.llm_ensemble.generate_with_context(
                system_message=prompt["system"],
                messages=[{"role": "user", "content": prompt["user"]}],
            )
            generation_time = time.time() - generation_start
            if llm_response is None:
                task_state.local_iteration = iteration_number
                return fail("LLM returned no response")

            if task_state.config.diff_based_evolution:
                diff_blocks = extract_diffs(llm_response, task_state.config.diff_pattern)
                if not diff_blocks:
                    task_state.local_iteration = iteration_number
                    return fail("no valid diffs found in response")

                if task_state.config.prompt.programs_as_changes_description:
                    try:
                        code_blocks, desc_blocks, _unmatched = split_diffs_by_target(
                            diff_blocks,
                            code_text=parent.code,
                            changes_description_text=parent_changes_desc,
                        )
                    except Exception as exc:
                        task_state.local_iteration = iteration_number
                        return fail(str(exc))

                    child_code, _ = apply_diff_blocks(parent.code, code_blocks)
                    child_changes_desc, desc_applied = apply_diff_blocks(
                        parent_changes_desc, desc_blocks
                    )
                    if (
                        desc_applied == 0
                        or not child_changes_desc.strip()
                        or child_changes_desc.strip() == parent_changes_desc.strip()
                    ):
                        task_state.local_iteration = iteration_number
                        return fail("changes_description was not updated or is empty")

                    changes_summary = format_diff_summary(
                        code_blocks,
                        max_line_len=task_state.config.prompt.diff_summary_max_line_len,
                        max_lines=task_state.config.prompt.diff_summary_max_lines,
                    )
                else:
                    child_code = apply_diff(
                        parent.code, llm_response, task_state.config.diff_pattern
                    )
                    changes_summary = format_diff_summary(
                        diff_blocks,
                        max_line_len=task_state.config.prompt.diff_summary_max_line_len,
                        max_lines=task_state.config.prompt.diff_summary_max_lines,
                    )
            else:
                new_code = parse_full_rewrite(llm_response, task_state.config.language)
                if not new_code:
                    task_state.local_iteration = iteration_number
                    return fail("no valid rewritten program found in response")
                child_code = new_code
                changes_summary = "Full rewrite"

            if len(child_code) > task_state.config.max_code_length:
                task_state.local_iteration = iteration_number
                return fail(
                    f"generated code exceeds max length ({len(child_code)} > {task_state.config.max_code_length})"
                )

            child_program_id = str(uuid.uuid4())
            evaluation_start = time.time()
            child_metrics = await task_state.evaluator.evaluate_program(
                child_code, child_program_id
            )
            evaluation_time = time.time() - evaluation_start
            artifacts = task_state.evaluator.get_pending_artifacts(child_program_id)
            iteration_time = time.time() - iteration_start

            template_key = (
                "full_rewrite_user"
                if not task_state.config.diff_based_evolution
                else "diff_user"
            )
            metadata = {
                "changes": changes_summary,
                "parent_metrics": parent.metrics,
            }
            if foreign_inspiration_sources:
                metadata["foreign_inspiration_sources"] = foreign_inspiration_sources

            child_program = Program(
                id=child_program_id,
                code=child_code,
                changes_description=child_changes_desc,
                language=task_state.config.language,
                parent_id=parent.id,
                generation=parent.generation + 1,
                metrics=child_metrics,
                iteration_found=iteration_number,
                metadata=metadata,
            )

            task_state.database.add(
                child_program,
                iteration=iteration_number,
                target_island=target_island,
            )
            if artifacts:
                task_state.database.store_artifacts(child_program_id, artifacts)
            if task_state.database.config.log_prompts:
                task_state.database.log_prompt(
                    program_id=child_program.id,
                    template_key=template_key,
                    prompt=dict(prompt),
                    responses=[llm_response],
                )

            task_state.database.increment_island_generation(island_idx=target_island)
            if task_state.database.should_migrate():
                logger.info(
                    "Task '%s' performing intra-task migration at local iteration %d",
                    task_state.task_name,
                    iteration_number,
                )
                task_state.database.migrate_programs()

            if task_state.evolution_tracer:
                task_state.evolution_tracer.log_trace(
                    iteration=iteration_number,
                    parent_program=parent,
                    child_program=child_program,
                    prompt=prompt,
                    llm_response=llm_response,
                    artifacts=artifacts,
                    island_id=target_island,
                    metadata={
                        "iteration_time": iteration_time,
                        "changes": changes_summary,
                        "task_name": task_state.task_name,
                    },
                )

            task_state.local_iteration = iteration_number

        logger.info(
            "Task '%s' iteration %d completed in %.2fs: %s",
            task_state.task_name,
            iteration_number,
            iteration_time,
            format_metrics_safe(child_program.metrics),
        )

        if (
            "combined_score" not in child_program.metrics
            and not hasattr(task_state, "_warned_about_combined_score")
        ):
            setattr(task_state, "_warned_about_combined_score", True)
            logger.warning(
                "Task '%s' returned no combined_score; using safe numeric average %.4f for ranking guidance",
                task_state.task_name,
                safe_numeric_average(child_program.metrics),
            )

        if task_state.database.best_program_id == child_program.id:
            logger.info(
                "Task '%s' found a new best program at local iteration %d",
                task_state.task_name,
                iteration_number,
            )

        return TaskIterationResult(
            task_name=task_state.task_name,
            local_iteration=iteration_number,
            success=True,
            child_program=child_program,
            generation_time_sec=generation_time,
            evaluation_time_sec=evaluation_time,
            iteration_time_sec=iteration_time,
            foreign_inspiration_sources=foreign_inspiration_sources,
            foreign_transfer_trigger_reason=foreign_transfer_trigger_reason,
            chosen_transfer_arm=decision.chosen_transfer_arm,
        )

    def _save_task_best_program(
        self, task_state: TaskState, base_dir: Optional[str] = None
    ) -> Optional[Dict[str, str]]:
        best_program = task_state.database.get_best_program()
        if not best_program:
            return None

        best_dir = Path(base_dir) if base_dir else Path(task_state.output_dir) / "best"
        best_dir.mkdir(parents=True, exist_ok=True)

        code_path = best_dir / f"best_program{task_state.file_extension}"
        with open(code_path, "w") as handle:
            handle.write(best_program.code)

        info_path = best_dir / "best_program_info.json"
        with open(info_path, "w") as handle:
            json.dump(
                {
                    "id": best_program.id,
                    "generation": best_program.generation,
                    "iteration": best_program.iteration_found,
                    "timestamp": best_program.timestamp,
                    "parent_id": best_program.parent_id,
                    "metrics": best_program.metrics,
                    "language": best_program.language,
                    "task_name": task_state.task_name,
                    "saved_at": time.time(),
                },
                handle,
                indent=2,
            )
        return {"code_path": str(code_path), "info_path": str(info_path)}

    def _snapshot_task_artifacts(
        self, task_state: TaskState, task_checkpoint_path: Path
    ) -> None:
        """Copy large disk-backed artifacts into the task checkpoint and rewrite paths."""
        programs_dir = task_checkpoint_path / "programs"
        artifacts_root = task_checkpoint_path / "artifacts"

        for program in task_state.database.programs.values():
            if not program.artifact_dir:
                continue

            source_artifact_dir = Path(program.artifact_dir)
            if not source_artifact_dir.exists():
                logger.warning(
                    "Task '%s' program '%s' artifact_dir does not exist during checkpoint save: %s",
                    task_state.task_name,
                    program.id,
                    source_artifact_dir,
                )
                continue

            checkpoint_artifact_dir = artifacts_root / program.id
            if checkpoint_artifact_dir.exists():
                shutil.rmtree(checkpoint_artifact_dir)
            checkpoint_artifact_dir.parent.mkdir(parents=True, exist_ok=True)
            shutil.copytree(source_artifact_dir, checkpoint_artifact_dir)

            program_path = programs_dir / f"{program.id}.json"
            if not program_path.exists():
                logger.warning(
                    "Task '%s' checkpoint is missing saved program metadata for artifact remap: %s",
                    task_state.task_name,
                    program_path,
                )
                continue

            with open(program_path, "r") as handle:
                program_data = json.load(handle)
            program_data["artifact_dir"] = str(Path("artifacts") / program.id)
            with open(program_path, "w") as handle:
                json.dump(program_data, handle)

    def _save_checkpoint(self, global_iteration: int) -> str:
        checkpoint_root = Path(self.output_dir) / "checkpoints" / f"checkpoint_{global_iteration}"
        checkpoint_root.mkdir(parents=True, exist_ok=True)
        tasks_checkpoint_root = checkpoint_root / "tasks"
        tasks_checkpoint_root.mkdir(parents=True, exist_ok=True)

        task_checkpoints: Dict[str, str] = {}
        task_iterations: Dict[str, int] = {}
        task_random_states: Dict[str, str] = {}
        task_numpy_random_states: Dict[str, str] = {}
        llm_random_states: Dict[str, str] = {}
        evaluator_llm_random_states: Dict[str, str] = {}

        for task_state in self.tasks:
            task_checkpoint_relative_path = Path("tasks") / task_state.task_name
            task_checkpoint_path = checkpoint_root / task_checkpoint_relative_path
            task_checkpoint_path.mkdir(parents=True, exist_ok=True)
            task_state.database.save(str(task_checkpoint_path), iteration=task_state.local_iteration)
            self._snapshot_task_artifacts(task_state, task_checkpoint_path)
            self._save_task_best_program(task_state, base_dir=str(task_checkpoint_path))

            with open(task_checkpoint_path / "task_state.json", "w") as handle:
                json.dump(
                    {
                        "task_name": task_state.task_name,
                        "local_iteration": task_state.local_iteration,
                        "global_iteration": global_iteration,
                    },
                    handle,
                    indent=2,
                )

            task_state.checkpoint_metadata = {
                "path": str(task_checkpoint_path),
                "global_iteration": global_iteration,
                "local_iteration": task_state.local_iteration,
            }
            task_checkpoints[task_state.task_name] = str(task_checkpoint_relative_path)
            task_iterations[task_state.task_name] = task_state.local_iteration
            task_random_states[task_state.task_name] = _serialize_state(task_state.random_state)
            task_numpy_random_states[task_state.task_name] = _serialize_state(
                task_state.numpy_random_state
            )
            llm_random_states[task_state.task_name] = _serialize_state(
                task_state.llm_ensemble.random_state.getstate()
            )
            evaluator_llm_random_states[task_state.task_name] = _serialize_state(
                task_state.llm_evaluator_ensemble.random_state.getstate()
            )

        with open(checkpoint_root / "multitask_state.json", "w") as handle:
            json.dump(
                {
                    "completed_global_iterations": global_iteration,
                    "next_task_index": self.next_task_index,
                    "task_iterations": task_iterations,
                    "task_no_improve_steps": {
                        task.task_name: task.no_improve_steps for task in self.tasks
                    },
                    "task_last_improvement_iterations": {
                        task.task_name: task.last_improvement_iteration for task in self.tasks
                    },
                    "task_last_transfer_iterations": {
                        task.task_name: task.last_transfer_iteration for task in self.tasks
                    },
                    "task_transfer_bandit_alpha": {
                        task.task_name: task.transfer_bandit_alpha for task in self.tasks
                    },
                    "task_transfer_bandit_beta": {
                        task.task_name: task.transfer_bandit_beta for task in self.tasks
                    },
                    "task_transfer_bandit_pulls": {
                        task.task_name: task.transfer_bandit_pulls for task in self.tasks
                    },
                    "task_recent_child_fitness_history": {
                        task.task_name: task.recent_child_fitness_history for task in self.tasks
                    },
                    "task_checkpoints": task_checkpoints,
                    "task_random_states": task_random_states,
                    "task_numpy_random_states": task_numpy_random_states,
                    "llm_random_states": llm_random_states,
                    "evaluator_llm_random_states": evaluator_llm_random_states,
                },
                handle,
                indent=2,
            )

        with open(checkpoint_root / "multitask_config_snapshot.json", "w") as handle:
            json.dump(
                {
                    "saved_at": time.time(),
                    "multitask_config": _normalize_json_compatible(asdict(self.multitask_config)),
                    "resume_validation": {
                        "hash": self._resume_validation_hash,
                        "payload": self._resume_validation_payload,
                    },
                },
                handle,
                indent=2,
            )

        self._last_checkpoint_iteration = global_iteration
        logger.info("Saved multitask checkpoint at global iteration %d", global_iteration)
        self.wandb_logger.log_checkpoint_artifact(
            str(checkpoint_root),
            metadata={
                "global_iteration": global_iteration,
                "scheduler_counts": dict(self._scheduler_counts),
            },
        )
        return str(checkpoint_root)

    def _load_checkpoint(self, checkpoint_path: str, force_resume: bool = False) -> None:
        checkpoint_root = Path(checkpoint_path).resolve()
        self._validate_checkpoint_config_snapshot(
            checkpoint_root=checkpoint_root, force_resume=force_resume
        )
        state_path = checkpoint_root / "multitask_state.json"
        if not state_path.exists():
            raise FileNotFoundError(f"Multitask checkpoint metadata not found: {state_path}")

        with open(state_path, "r") as handle:
            state = json.load(handle)

        self.completed_global_iterations = state.get("completed_global_iterations", 0)
        self.next_task_index = state.get("next_task_index", 0)
        self._last_checkpoint_iteration = self.completed_global_iterations

        task_checkpoints = state.get("task_checkpoints", {})
        task_iterations = state.get("task_iterations", {})
        task_no_improve_steps = state.get("task_no_improve_steps", {})
        task_last_improvement_iterations = state.get("task_last_improvement_iterations", {})
        task_last_transfer_iterations = state.get("task_last_transfer_iterations", {})
        task_transfer_bandit_alpha = state.get("task_transfer_bandit_alpha", {})
        task_transfer_bandit_beta = state.get("task_transfer_bandit_beta", {})
        task_transfer_bandit_pulls = state.get("task_transfer_bandit_pulls", {})
        task_recent_child_fitness_history = state.get("task_recent_child_fitness_history", {})
        task_random_states = state.get("task_random_states", {})
        task_numpy_random_states = state.get("task_numpy_random_states", {})
        llm_random_states = state.get("llm_random_states", {})
        evaluator_llm_random_states = state.get("evaluator_llm_random_states", {})

        for task_state in self.tasks:
            task_checkpoint_path = task_checkpoints.get(task_state.task_name)
            if not task_checkpoint_path:
                raise ValueError(
                    f"Checkpoint is missing state for task '{task_state.task_name}'"
                )
            resolved_task_checkpoint_path = Path(task_checkpoint_path)
            if not resolved_task_checkpoint_path.is_absolute():
                resolved_task_checkpoint_path = checkpoint_root / resolved_task_checkpoint_path
            resolved_task_checkpoint_path = resolved_task_checkpoint_path.resolve()

            task_state.database.load(str(resolved_task_checkpoint_path))
            task_state.local_iteration = task_iterations.get(
                task_state.task_name, task_state.database.last_iteration
            )
            task_state.no_improve_steps = task_no_improve_steps.get(task_state.task_name, 0)
            task_state.last_improvement_iteration = task_last_improvement_iterations.get(
                task_state.task_name
            )
            task_state.last_transfer_iteration = task_last_transfer_iterations.get(
                task_state.task_name
            )
            task_state.transfer_bandit_alpha = {
                str(arm_name): float(value)
                for arm_name, value in (
                    task_transfer_bandit_alpha.get(task_state.task_name) or {}
                ).items()
            }
            task_state.transfer_bandit_beta = {
                str(arm_name): float(value)
                for arm_name, value in (
                    task_transfer_bandit_beta.get(task_state.task_name) or {}
                ).items()
            }
            task_state.transfer_bandit_pulls = {
                str(arm_name): int(value)
                for arm_name, value in (
                    task_transfer_bandit_pulls.get(task_state.task_name) or {}
                ).items()
            }
            task_state.recent_child_fitness_history = _normalize_recent_child_fitness_history(
                task_recent_child_fitness_history.get(task_state.task_name)
            )
            self._ensure_transfer_bandit_state(task_state)
            self._trim_recent_child_fitness_history(task_state)
            task_state.random_state = _deserialize_state(
                task_random_states[task_state.task_name]
            )
            task_state.numpy_random_state = _deserialize_state(
                task_numpy_random_states[task_state.task_name]
            )
            task_state.llm_ensemble.random_state.setstate(
                _deserialize_state(llm_random_states[task_state.task_name])
            )
            task_state.llm_evaluator_ensemble.random_state.setstate(
                _deserialize_state(evaluator_llm_random_states[task_state.task_name])
            )
            task_state.checkpoint_metadata = {
                "path": str(resolved_task_checkpoint_path),
                "global_iteration": self.completed_global_iterations,
                "local_iteration": task_state.local_iteration,
            }

        logger.info(
            "Loaded multitask checkpoint from %s at global iteration %d",
            checkpoint_root,
            self.completed_global_iterations,
        )

    def _write_summary(self, *, wall_clock_time_sec: Optional[float]) -> Dict[str, Optional[Program]]:
        summary: Dict[str, Any] = {
            "completed_global_iterations": self.completed_global_iterations,
            "wall_clock_time_sec": wall_clock_time_sec,
            "scheduler_counts": dict(self._scheduler_counts),
            "tasks": {},
        }
        best_programs: Dict[str, Optional[Program]] = {}

        for task_state in self.tasks:
            best_program = task_state.database.get_best_program()
            best_programs[task_state.task_name] = best_program
            best_fitness = self._get_task_best_fitness(task_state)
            summary["tasks"][task_state.task_name] = {
                "output_dir": task_state.output_dir,
                "local_iteration": task_state.local_iteration,
                "no_improve_steps": task_state.no_improve_steps,
                "last_improvement_iteration": task_state.last_improvement_iteration,
                "last_transfer_iteration": task_state.last_transfer_iteration,
                "iterations_since_improvement": self._iterations_since_improvement(task_state),
                "stagnating": self._is_task_stagnating(task_state),
                "stagnation_ready": self._is_stagnation_trigger_ready(task_state),
                "best_program_id": best_program.id if best_program else None,
                "best_iteration": best_program.iteration_found if best_program else None,
                "best_fitness": best_fitness,
                "best_metrics": best_program.metrics if best_program else None,
            }

        summary.update(self._aggregate_best_fitness())

        with open(Path(self.output_dir) / "summary.json", "w") as handle:
            json.dump(summary, handle, indent=2)

        return best_programs

    def _close_tracers(self) -> None:
        for task_state in self.tasks:
            if task_state.evolution_tracer:
                task_state.evolution_tracer.close()

    async def run(
        self,
        checkpoint_path: Optional[str] = None,
        max_global_iterations: Optional[int] = None,
        force_resume: bool = False,
        max_waves: Optional[int] = None,
    ) -> Dict[str, Optional[Program]]:
        if max_waves is not None:
            raise ValueError(
                "max_waves/--waves is only supported for "
                "execution_mode='parallel_synchronized_waves'"
            )
        total_iterations = max_global_iterations or self.multitask_config.max_global_iterations
        self._run_started_at = time.time()

        self.wandb_logger.init_run(
            run_mode="multitask",
            config_payload={
                "multitask": self.multitask_config,
                "base_config": self.base_config,
            },
            metadata=self._build_wandb_run_metadata(
                checkpoint_path=checkpoint_path,
                max_global_iterations=total_iterations,
            ),
            step_metric="multitask/global_iteration",
        )

        try:
            if checkpoint_path:
                self._load_checkpoint(checkpoint_path, force_resume=force_resume)
            else:
                for task_state in self.tasks:
                    await self._ensure_initial_program(task_state)

            self._initialize_task_progress_state(from_checkpoint=bool(checkpoint_path))

            self._log_multitask_initial_state()

            while self.completed_global_iterations < total_iterations:
                task_state = self.tasks[self.next_task_index]
                global_iteration = self.completed_global_iterations + 1
                self._scheduler_counts[task_state.task_name] = (
                    self._scheduler_counts.get(task_state.task_name, 0) + 1
                )
                result = await self._run_task_iteration(task_state)

                self.completed_global_iterations += 1
                self._log_multitask_step(
                    global_iteration=global_iteration,
                    task_state=task_state,
                    result=result,
                )
                self.next_task_index = (self.next_task_index + 1) % len(self.tasks)

                if (
                    self.completed_global_iterations > 0
                    and self.completed_global_iterations
                    % self.multitask_config.checkpoint_interval
                    == 0
                ):
                    self._save_checkpoint(self.completed_global_iterations)

            if (
                self.completed_global_iterations > 0
                and self._last_checkpoint_iteration != self.completed_global_iterations
            ):
                self._save_checkpoint(self.completed_global_iterations)

            for task_state in self.tasks:
                self._save_task_best_program(task_state)
                best_program = task_state.database.get_best_program()
                if best_program is not None:
                    self._log_task_best_program_artifact(task_state, best_program)

            wall_clock_time_sec = time.time() - self._run_started_at if self._run_started_at else None
            best_programs = self._write_summary(wall_clock_time_sec=wall_clock_time_sec)
            self.wandb_logger.update_summary(
                {
                    "completed_global_iterations": self.completed_global_iterations,
                    "wall_clock_time_sec": wall_clock_time_sec,
                    "scheduler_counts": dict(self._scheduler_counts),
                    **self._aggregate_best_fitness(),
                }
            )
            summary_path = Path(self.output_dir) / "summary.json"
            self.wandb_logger.log_file_artifact(
                str(summary_path),
                artifact_name="final-summary",
                artifact_type="summary",
                metadata={
                    "completed_global_iterations": self.completed_global_iterations,
                    "wall_clock_time_sec": wall_clock_time_sec,
                },
            )
            return best_programs
        finally:
            self._close_tracers()
            self.wandb_logger.finish()


class ParallelWaveMultiTaskOpenEvolve(MultiTaskOpenEvolve):
    """Synchronized-wave multitask controller with one dedicated worker per task."""

    def __init__(self, multitask_config: MultitaskConfig):
        self.completed_waves = 0
        self._last_checkpoint_wave = 0
        self._task_workers: Dict[str, DedicatedTaskWorker] = {}
        super().__init__(multitask_config)

    def _create_task_state(self, task_config: TaskConfig) -> TaskState:
        task_output_dir = os.path.join(
            self.output_dir, task_config.output_subdir or task_config.name
        )
        os.makedirs(task_output_dir, exist_ok=True)

        config = derive_task_config(
            base_config=self.base_config,
            overrides=task_config.config_overrides,
            config_dir=Path(self.multitask_config.config_dir or "."),
        )
        self._configure_task_output_paths(
            task_name=task_config.name,
            task_output_dir=task_output_dir,
            config=config,
        )
        config.database.db_path = None

        initial_program_path = task_config.initial_program
        with open(initial_program_path, "r") as handle:
            initial_program_code = handle.read()

        file_extension = Path(initial_program_path).suffix or ".py"
        if not config.language:
            config.language = extract_code_language(initial_program_code)
        if not getattr(config, "file_suffix", None) or config.file_suffix == ".py":
            config.file_suffix = file_extension

        previous_random_state = random.getstate()
        previous_numpy_state = np.random.get_state()

        try:
            self._ensure_task_log_handler(task_config.name, config.log_dir)
            with self._activate_task_logging(task_config.name):
                self._configure_task_random_seed(config)
                self._setup_manual_mode_queue(config=config, output_dir=task_output_dir)

                llm_ensemble = LLMEnsemble(config.llm.models)
                llm_evaluator_ensemble = LLMEnsemble(config.llm.evaluator_models)

                config.database.novelty_llm = llm_ensemble
                database = ProgramDatabase(config.database)
                evolution_tracer = self._create_evolution_tracer(config, task_output_dir)

                task_state = TaskState(
                    task_name=task_config.name,
                    initial_program_path=initial_program_path,
                    initial_program_code=initial_program_code,
                    evaluation_file=task_config.evaluation_file,
                    config=config,
                    database=database,
                    evaluator=None,
                    prompt_sampler=None,
                    evaluator_prompt_sampler=None,
                    llm_ensemble=llm_ensemble,
                    llm_evaluator_ensemble=llm_evaluator_ensemble,
                    output_dir=task_output_dir,
                    file_extension=file_extension,
                    env=task_config.env,
                    related_tasks=task_config.related_tasks,
                    random_state=random.getstate(),
                    numpy_random_state=np.random.get_state(),
                    evolution_tracer=evolution_tracer,
                )
        finally:
            random.setstate(previous_random_state)
            np.random.set_state(previous_numpy_state)

        logger.info(
            "Initialized synchronized-wave multitask task '%s' with output dir %s",
            task_state.task_name,
            task_state.output_dir,
        )
        return task_state

    def _create_task_worker(self, task_state: TaskState) -> DedicatedTaskWorker:
        worker_log_path = self._parallel_worker_log_path(
            task_name=task_state.task_name,
            log_dir=task_state.config.log_dir,
        )
        logger.info(
            "Parallel worker log for task '%s' will be written to %s",
            task_state.task_name,
            worker_log_path,
        )
        return DedicatedTaskWorker(
            task_name=task_state.task_name,
            config=task_state.config,
            evaluation_file=task_state.evaluation_file,
            task_env=task_state.env,
            worker_log_path=worker_log_path,
        )

    def _start_task_workers(self) -> None:
        for task_state in self.tasks:
            worker = self._task_workers.get(task_state.task_name)
            if worker is None:
                worker = self._create_task_worker(task_state)
                self._task_workers[task_state.task_name] = worker
            worker.start()

    def _stop_task_workers(self) -> None:
        for worker in self._task_workers.values():
            worker.stop()
        self._task_workers.clear()

    async def _ensure_initial_program(self, task_state: TaskState) -> None:
        if task_state.database.programs:
            return

        initial_program_id = str(uuid.uuid4())
        worker = self._task_workers.get(task_state.task_name)
        if worker is None:
            raise RuntimeError(f"No dedicated worker initialized for task '{task_state.task_name}'")

        future = worker.submit_initial_program(
            InitialProgramEvaluationRequest(
                program_id=initial_program_id,
                code=task_state.initial_program_code,
                rng_state=self._build_worker_rng_state(task_state),
            )
        )

        try:
            result = await asyncio.wrap_future(future)
        except Exception as exc:
            raise RuntimeError(
                f"Failed to evaluate initial program for task '{task_state.task_name}'"
            ) from exc

        self._apply_worker_rng_state(task_state, result.rng_state)

        initial_program = Program(
            id=initial_program_id,
            code=task_state.initial_program_code,
            changes_description=task_state.config.prompt.initial_changes_description,
            language=task_state.config.language,
            metrics=result.metrics,
            iteration_found=0,
        )
        task_state.database.add(initial_program, iteration=0)
        if result.artifacts:
            task_state.database.store_artifacts(initial_program_id, result.artifacts)

        logger.info(
            "Initialized task '%s' with initial metrics: %s",
            task_state.task_name,
            format_metrics_safe(result.metrics),
        )

    def _commit_parallel_task_result(
        self,
        task_state: TaskState,
        result: TaskIterationWorkerResult,
    ) -> TaskIterationResult:
        self._apply_worker_rng_state(task_state, result.rng_state)

        if not result.success:
            task_state.local_iteration = result.local_iteration
            logger.warning(
                "Task '%s' iteration %d discarded: %s",
                task_state.task_name,
                result.local_iteration,
                result.failure_reason,
            )
            return TaskIterationResult(
                task_name=task_state.task_name,
                local_iteration=result.local_iteration,
                success=False,
                failure_reason=result.failure_reason,
                generation_time_sec=result.generation_time_sec,
                evaluation_time_sec=result.evaluation_time_sec,
                iteration_time_sec=result.iteration_time_sec,
                foreign_inspiration_sources=result.foreign_inspiration_sources,
                foreign_transfer_trigger_reason=result.foreign_transfer_trigger_reason,
                chosen_transfer_arm=result.chosen_transfer_arm,
            )

        if not result.child_program_dict:
            raise RuntimeError(
                f"Task '{task_state.task_name}' reported success without a child program"
            )

        child_program = Program.from_dict(result.child_program_dict)
        parent_program = (
            task_state.database.get(child_program.parent_id) if child_program.parent_id else None
        )
        template_key = (
            "full_rewrite_user"
            if not task_state.config.diff_based_evolution
            else "diff_user"
        )

        with self._activate_task_context(task_state):
            task_state.database.add(
                child_program,
                iteration=result.local_iteration,
                target_island=result.target_island,
            )
            if result.artifacts:
                task_state.database.store_artifacts(child_program.id, result.artifacts)
            if task_state.database.config.log_prompts and result.prompt:
                task_state.database.log_prompt(
                    program_id=child_program.id,
                    template_key=template_key,
                    prompt=dict(result.prompt),
                    responses=[result.llm_response] if result.llm_response else [],
                )

            task_state.database.increment_island_generation(island_idx=result.target_island)
            if task_state.database.should_migrate():
                logger.info(
                    "Task '%s' performing intra-task migration at local iteration %d",
                    task_state.task_name,
                    result.local_iteration,
                )
                task_state.database.migrate_programs()

        task_state.local_iteration = result.local_iteration

        if task_state.evolution_tracer and parent_program is not None:
            task_state.evolution_tracer.log_trace(
                iteration=result.local_iteration,
                parent_program=parent_program,
                child_program=child_program,
                prompt=result.prompt,
                llm_response=result.llm_response,
                artifacts=result.artifacts,
                island_id=result.target_island,
                metadata={
                    "iteration_time": result.iteration_time_sec,
                    "changes": child_program.metadata.get("changes", ""),
                    "task_name": task_state.task_name,
                    "wave": result.local_iteration,
                },
            )

        logger.info(
            "Task '%s' wave %d completed in %.2fs: %s",
            task_state.task_name,
            result.local_iteration,
            result.iteration_time_sec or 0.0,
            format_metrics_safe(child_program.metrics),
        )

        if (
            "combined_score" not in child_program.metrics
            and not hasattr(task_state, "_warned_about_combined_score")
        ):
            setattr(task_state, "_warned_about_combined_score", True)
            logger.warning(
                "Task '%s' returned no combined_score; using safe numeric average %.4f for ranking guidance",
                task_state.task_name,
                safe_numeric_average(child_program.metrics),
            )

        if task_state.database.best_program_id == child_program.id:
            logger.info(
                "Task '%s' found a new best program at local iteration %d",
                task_state.task_name,
                result.local_iteration,
            )

        return TaskIterationResult(
            task_name=task_state.task_name,
            local_iteration=result.local_iteration,
            success=True,
            child_program=child_program,
            generation_time_sec=result.generation_time_sec,
            evaluation_time_sec=result.evaluation_time_sec,
            iteration_time_sec=result.iteration_time_sec,
            foreign_inspiration_sources=result.foreign_inspiration_sources,
            foreign_transfer_trigger_reason=result.foreign_transfer_trigger_reason,
            chosen_transfer_arm=result.chosen_transfer_arm,
        )

    def _validate_wave_lockstep(self) -> None:
        expected_iteration = self.completed_waves
        task_iterations = {task.task_name: task.local_iteration for task in self.tasks}
        if len(set(task_iterations.values())) != 1:
            raise RuntimeError(
                "Parallel multitask wave execution requires all tasks to stay in lockstep; "
                f"found task iterations {task_iterations}"
            )
        if any(iteration != expected_iteration for iteration in task_iterations.values()):
            raise RuntimeError(
                "Parallel multitask resume state is inconsistent with completed_waves; "
                f"expected {expected_iteration}, found {task_iterations}"
            )

    def _log_parallel_wave(
        self,
        *,
        wave_index: int,
        committed_results: List[TaskIterationResult],
    ) -> None:
        metrics: Dict[str, Any] = {
            "multitask/global_iteration": self.completed_global_iterations,
            "multitask/completed_waves": wave_index,
            "multitask/num_active_tasks": len(self.tasks),
            "multitask/wave_successes": sum(1 for result in committed_results if result.success),
            "multitask/wave_failures": sum(1 for result in committed_results if not result.success),
        }

        for result in committed_results:
            task_state = self.task_by_name[result.task_name]
            task_name = task_state.task_name
            task_prefix = f"task/{task_name}"
            previous_best = self._task_best_fitness.get(task_name)
            current_task_best = self._get_task_best_fitness(task_state)
            if current_task_best is not None:
                self._task_best_fitness[task_name] = current_task_best

            foreign_transfer_used = bool(result.foreign_inspiration_sources)
            progress_update = self._update_task_progress_state(
                task_state=task_state,
                previous_best=previous_best,
                current_best=current_task_best,
                local_iteration=result.local_iteration,
                foreign_transfer_used=foreign_transfer_used,
                chosen_transfer_arm=result.chosen_transfer_arm,
                child_program=result.child_program,
            )
            iterations_since_improvement = self._iterations_since_improvement(task_state)
            stagnating = self._is_task_stagnating(task_state)
            stagnation_ready = self._is_stagnation_trigger_ready(task_state)
            transfer_trigger_reason = result.foreign_transfer_trigger_reason or ""

            current_fitness = None
            if result.child_program is not None:
                current_fitness = get_fitness_score(
                    result.child_program.metrics,
                    task_state.database.config.feature_dimensions,
                )

            metrics.update(
                {
                    f"{task_prefix}/task_local_iteration": task_state.local_iteration,
                    f"{task_prefix}/current_fitness": current_fitness,
                    f"{task_prefix}/best_fitness": current_task_best,
                    f"{task_prefix}/delta_best_fitness": progress_update.delta_best,
                    f"{task_prefix}/evaluation_success": 1 if result.success else 0,
                    f"{task_prefix}/iteration_time_sec": result.iteration_time_sec,
                    f"{task_prefix}/generation_time_sec": result.generation_time_sec,
                    f"{task_prefix}/evaluation_time_sec": result.evaluation_time_sec,
                    f"{task_prefix}/failure_reason": result.failure_reason,
                    f"{task_prefix}/no_improve_steps": task_state.no_improve_steps,
                    f"{task_prefix}/trigger_mode": self.multitask_config.foreign_inspirations.trigger_mode,
                    f"{task_prefix}/last_improvement_iteration": (
                        task_state.last_improvement_iteration
                        if task_state.last_improvement_iteration is not None
                        else -1
                    ),
                    f"{task_prefix}/last_transfer_iteration": (
                        task_state.last_transfer_iteration
                        if task_state.last_transfer_iteration is not None
                        else -1
                    ),
                    f"{task_prefix}/iterations_since_improvement": iterations_since_improvement,
                    f"{task_prefix}/stagnating": 1 if stagnating else 0,
                    f"{task_prefix}/stagnation_ready": 1 if stagnation_ready else 0,
                    f"{task_prefix}/foreign_inspirations_used": 1 if foreign_transfer_used else 0,
                    f"{task_prefix}/foreign_transfer_used": 1 if foreign_transfer_used else 0,
                    f"{task_prefix}/foreign_transfer_triggered": 1 if foreign_transfer_used else 0,
                    f"{task_prefix}/foreign_transfer_trigger_reason": transfer_trigger_reason,
                    f"{task_prefix}/chosen_transfer_arm": result.chosen_transfer_arm or "",
                    f"{task_prefix}/foreign_transfer_reward": (
                        progress_update.reward_for_chosen_arm
                        if progress_update.reward_for_chosen_arm is not None
                        else -1
                    ),
                    f"{task_prefix}/num_foreign_inspirations": len(
                        result.foreign_inspiration_sources
                    ),
                    f"{task_prefix}/foreign_inspiration_sources": ",".join(
                        result.foreign_inspiration_sources
                    ),
                }
            )
            if progress_update.reward_mode is not None:
                metrics.update(
                    {
                        f"{task_prefix}/bandit_reward_mode": progress_update.reward_mode,
                        f"{task_prefix}/bandit_reward_child_fitness": (
                            progress_update.child_fitness_for_reward
                        ),
                        f"{task_prefix}/bandit_reward_baseline": (
                            progress_update.reward_baseline_fitness
                        ),
                    }
                )
            self._add_transfer_bandit_metrics(metrics, task_prefix, task_state)
            logger.info(
                "Task '%s' committed transfer state at wave %d: mode=%s no_improve_steps=%d "
                "chosen_arm=%s foreign_used=%s reward_mode=%s reward=%s child_fitness=%s "
                "reward_baseline=%s last_transfer_iteration=%s bandit_scores=%s",
                task_name,
                result.local_iteration,
                self.multitask_config.foreign_inspirations.trigger_mode,
                task_state.no_improve_steps,
                result.chosen_transfer_arm,
                foreign_transfer_used,
                progress_update.reward_mode,
                progress_update.reward_for_chosen_arm,
                progress_update.child_fitness_for_reward,
                progress_update.reward_baseline_fitness,
                task_state.last_transfer_iteration,
                self._compute_transfer_bandit_scores(task_state)
                if self.multitask_config.foreign_inspirations.trigger_mode == "online_bandit"
                else {},
            )

            if result.child_program is not None:
                metrics.update(
                    {
                        key: value
                        for key, value in flatten_scalars(
                            result.child_program.metrics,
                            prefix=task_prefix,
                        ).items()
                        if key not in metrics
                    }
                )
                if task_state.database.best_program_id == result.child_program.id:
                    self._log_task_best_program_artifact(task_state, result.child_program)
                    self.wandb_logger.update_summary(
                        {
                            f"{task_prefix}/best_fitness": current_task_best,
                            f"{task_prefix}/best_iteration": result.child_program.iteration_found,
                        }
                    )

        metrics.update(self._aggregate_best_fitness())
        self.wandb_logger.log_metrics(metrics, step=self.completed_global_iterations)

    def _save_checkpoint(self, wave_index: int) -> str:
        checkpoint_root = (
            Path(self.output_dir) / "checkpoints" / f"checkpoint_wave_{wave_index:04d}"
        )
        checkpoint_root.mkdir(parents=True, exist_ok=True)
        tasks_checkpoint_root = checkpoint_root / "tasks"
        tasks_checkpoint_root.mkdir(parents=True, exist_ok=True)

        task_checkpoints: Dict[str, str] = {}
        task_iterations: Dict[str, int] = {}
        task_random_states: Dict[str, str] = {}
        task_numpy_random_states: Dict[str, str] = {}
        llm_random_states: Dict[str, str] = {}
        evaluator_llm_random_states: Dict[str, str] = {}

        for task_state in self.tasks:
            task_checkpoint_relative_path = Path("tasks") / task_state.task_name
            task_checkpoint_path = checkpoint_root / task_checkpoint_relative_path
            task_checkpoint_path.mkdir(parents=True, exist_ok=True)
            task_state.database.save(str(task_checkpoint_path), iteration=task_state.local_iteration)
            self._snapshot_task_artifacts(task_state, task_checkpoint_path)
            self._save_task_best_program(task_state, base_dir=str(task_checkpoint_path))

            with open(task_checkpoint_path / "task_state.json", "w") as handle:
                json.dump(
                    {
                        "task_name": task_state.task_name,
                        "local_iteration": task_state.local_iteration,
                        "wave": wave_index,
                        "completed_global_iterations": self.completed_global_iterations,
                    },
                    handle,
                    indent=2,
                )

            task_state.checkpoint_metadata = {
                "path": str(task_checkpoint_path),
                "global_iteration": self.completed_global_iterations,
                "local_iteration": task_state.local_iteration,
                "wave": wave_index,
            }
            task_checkpoints[task_state.task_name] = str(task_checkpoint_relative_path)
            task_iterations[task_state.task_name] = task_state.local_iteration
            task_random_states[task_state.task_name] = _serialize_state(task_state.random_state)
            task_numpy_random_states[task_state.task_name] = _serialize_state(
                task_state.numpy_random_state
            )
            llm_random_states[task_state.task_name] = _serialize_state(
                task_state.llm_ensemble.random_state.getstate()
            )
            evaluator_llm_random_states[task_state.task_name] = _serialize_state(
                task_state.llm_evaluator_ensemble.random_state.getstate()
            )

        with open(checkpoint_root / "multitask_state.json", "w") as handle:
            json.dump(
                {
                    "execution_mode": "parallel_synchronized_waves",
                    "completed_waves": wave_index,
                    "completed_global_iterations": self.completed_global_iterations,
                    "task_iterations": task_iterations,
                    "task_no_improve_steps": {
                        task.task_name: task.no_improve_steps for task in self.tasks
                    },
                    "task_last_improvement_iterations": {
                        task.task_name: task.last_improvement_iteration for task in self.tasks
                    },
                    "task_last_transfer_iterations": {
                        task.task_name: task.last_transfer_iteration for task in self.tasks
                    },
                    "task_transfer_bandit_alpha": {
                        task.task_name: task.transfer_bandit_alpha for task in self.tasks
                    },
                    "task_transfer_bandit_beta": {
                        task.task_name: task.transfer_bandit_beta for task in self.tasks
                    },
                    "task_transfer_bandit_pulls": {
                        task.task_name: task.transfer_bandit_pulls for task in self.tasks
                    },
                    "task_recent_child_fitness_history": {
                        task.task_name: task.recent_child_fitness_history for task in self.tasks
                    },
                    "task_checkpoints": task_checkpoints,
                    "task_random_states": task_random_states,
                    "task_numpy_random_states": task_numpy_random_states,
                    "llm_random_states": llm_random_states,
                    "evaluator_llm_random_states": evaluator_llm_random_states,
                },
                handle,
                indent=2,
            )

        with open(checkpoint_root / "multitask_config_snapshot.json", "w") as handle:
            json.dump(
                {
                    "saved_at": time.time(),
                    "multitask_config": _normalize_json_compatible(asdict(self.multitask_config)),
                    "resume_validation": {
                        "hash": self._resume_validation_hash,
                        "payload": self._resume_validation_payload,
                    },
                },
                handle,
                indent=2,
            )

        self._last_checkpoint_wave = wave_index
        self._last_checkpoint_iteration = self.completed_global_iterations
        logger.info(
            "Saved synchronized-wave checkpoint at wave %d (global iteration %d)",
            wave_index,
            self.completed_global_iterations,
        )
        self.wandb_logger.log_checkpoint_artifact(
            str(checkpoint_root),
            metadata={
                "wave": wave_index,
                "global_iteration": self.completed_global_iterations,
            },
        )
        return str(checkpoint_root)

    def _load_checkpoint(self, checkpoint_path: str, force_resume: bool = False) -> None:
        checkpoint_root = Path(checkpoint_path).resolve()
        self._validate_checkpoint_config_snapshot(
            checkpoint_root=checkpoint_root, force_resume=force_resume
        )
        state_path = checkpoint_root / "multitask_state.json"
        if not state_path.exists():
            raise FileNotFoundError(f"Multitask checkpoint metadata not found: {state_path}")

        with open(state_path, "r") as handle:
            state = json.load(handle)

        self.completed_waves = state.get("completed_waves", 0)
        self.completed_global_iterations = state.get(
            "completed_global_iterations",
            self.completed_waves * len(self.tasks),
        )
        expected_global_iterations = self.completed_waves * len(self.tasks)
        if self.completed_global_iterations != expected_global_iterations:
            raise ValueError(
                "Parallel multitask checkpoint has inconsistent completed_global_iterations: "
                f"expected {expected_global_iterations}, found {self.completed_global_iterations}"
            )
        self._last_checkpoint_wave = self.completed_waves
        self._last_checkpoint_iteration = self.completed_global_iterations

        task_checkpoints = state.get("task_checkpoints", {})
        task_iterations = state.get("task_iterations", {})
        task_no_improve_steps = state.get("task_no_improve_steps", {})
        task_last_improvement_iterations = state.get("task_last_improvement_iterations", {})
        task_last_transfer_iterations = state.get("task_last_transfer_iterations", {})
        task_transfer_bandit_alpha = state.get("task_transfer_bandit_alpha", {})
        task_transfer_bandit_beta = state.get("task_transfer_bandit_beta", {})
        task_transfer_bandit_pulls = state.get("task_transfer_bandit_pulls", {})
        task_recent_child_fitness_history = state.get("task_recent_child_fitness_history", {})
        task_random_states = state.get("task_random_states", {})
        task_numpy_random_states = state.get("task_numpy_random_states", {})
        llm_random_states = state.get("llm_random_states", {})
        evaluator_llm_random_states = state.get("evaluator_llm_random_states", {})

        for task_state in self.tasks:
            task_checkpoint_path = task_checkpoints.get(task_state.task_name)
            if not task_checkpoint_path:
                raise ValueError(
                    f"Checkpoint is missing state for task '{task_state.task_name}'"
                )
            resolved_task_checkpoint_path = Path(task_checkpoint_path)
            if not resolved_task_checkpoint_path.is_absolute():
                resolved_task_checkpoint_path = checkpoint_root / resolved_task_checkpoint_path
            resolved_task_checkpoint_path = resolved_task_checkpoint_path.resolve()

            task_state.database.load(str(resolved_task_checkpoint_path))
            task_state.local_iteration = task_iterations.get(
                task_state.task_name, task_state.database.last_iteration
            )
            task_state.no_improve_steps = task_no_improve_steps.get(task_state.task_name, 0)
            task_state.last_improvement_iteration = task_last_improvement_iterations.get(
                task_state.task_name
            )
            task_state.last_transfer_iteration = task_last_transfer_iterations.get(
                task_state.task_name
            )
            task_state.transfer_bandit_alpha = {
                str(arm_name): float(value)
                for arm_name, value in (
                    task_transfer_bandit_alpha.get(task_state.task_name) or {}
                ).items()
            }
            task_state.transfer_bandit_beta = {
                str(arm_name): float(value)
                for arm_name, value in (
                    task_transfer_bandit_beta.get(task_state.task_name) or {}
                ).items()
            }
            task_state.transfer_bandit_pulls = {
                str(arm_name): int(value)
                for arm_name, value in (
                    task_transfer_bandit_pulls.get(task_state.task_name) or {}
                ).items()
            }
            task_state.recent_child_fitness_history = _normalize_recent_child_fitness_history(
                task_recent_child_fitness_history.get(task_state.task_name)
            )
            self._ensure_transfer_bandit_state(task_state)
            self._trim_recent_child_fitness_history(task_state)
            task_state.random_state = _deserialize_state(
                task_random_states[task_state.task_name]
            )
            task_state.numpy_random_state = _deserialize_state(
                task_numpy_random_states[task_state.task_name]
            )
            task_state.llm_ensemble.random_state.setstate(
                _deserialize_state(llm_random_states[task_state.task_name])
            )
            task_state.llm_evaluator_ensemble.random_state.setstate(
                _deserialize_state(evaluator_llm_random_states[task_state.task_name])
            )
            task_state.checkpoint_metadata = {
                "path": str(resolved_task_checkpoint_path),
                "global_iteration": self.completed_global_iterations,
                "local_iteration": task_state.local_iteration,
                "wave": self.completed_waves,
            }

        self._validate_wave_lockstep()
        logger.info(
            "Loaded synchronized-wave checkpoint from %s at wave %d",
            checkpoint_root,
            self.completed_waves,
        )

    def _write_summary(self, *, wall_clock_time_sec: Optional[float]) -> Dict[str, Optional[Program]]:
        summary: Dict[str, Any] = {
            "execution_mode": "parallel_synchronized_waves",
            "completed_global_iterations": self.completed_global_iterations,
            "completed_waves": self.completed_waves,
            "wall_clock_time_sec": wall_clock_time_sec,
            "scheduler_counts": dict(self._scheduler_counts),
            "tasks": {},
        }
        best_programs: Dict[str, Optional[Program]] = {}

        for task_state in self.tasks:
            best_program = task_state.database.get_best_program()
            best_programs[task_state.task_name] = best_program
            best_fitness = self._get_task_best_fitness(task_state)
            summary["tasks"][task_state.task_name] = {
                "output_dir": task_state.output_dir,
                "local_iteration": task_state.local_iteration,
                "best_program_id": best_program.id if best_program else None,
                "best_iteration": best_program.iteration_found if best_program else None,
                "best_fitness": best_fitness,
                "best_metrics": best_program.metrics if best_program else None,
            }

        summary.update(self._aggregate_best_fitness())

        with open(Path(self.output_dir) / "summary.json", "w") as handle:
            json.dump(summary, handle, indent=2)

        return best_programs

    async def run(
        self,
        checkpoint_path: Optional[str] = None,
        max_global_iterations: Optional[int] = None,
        force_resume: bool = False,
        max_waves: Optional[int] = None,
    ) -> Dict[str, Optional[Program]]:
        if max_global_iterations is not None:
            raise ValueError(
                "max_global_iterations/--iterations is only supported for "
                "execution_mode='sequential_round_robin'"
            )

        total_waves = (
            max_waves if max_waves is not None else self.multitask_config.max_waves
        )
        if total_waves is None:
            raise ValueError("multitask.max_waves must be set for synchronized-wave execution")

        self._run_started_at = time.time()
        self.wandb_logger.init_run(
            run_mode="multitask",
            config_payload={
                "multitask": self.multitask_config,
                "base_config": self.base_config,
            },
            metadata={
                **self._build_wandb_run_metadata(
                    checkpoint_path=checkpoint_path,
                    max_global_iterations=total_waves * len(self.tasks),
                ),
                "execution_mode": "parallel_synchronized_waves",
                "requested_waves": total_waves,
            },
            step_metric="multitask/global_iteration",
        )

        try:
            if checkpoint_path:
                self._load_checkpoint(checkpoint_path, force_resume=force_resume)

            self._start_task_workers()

            if not checkpoint_path:
                for task_state in self.tasks:
                    await self._ensure_initial_program(task_state)

            self._initialize_task_progress_state(from_checkpoint=bool(checkpoint_path))

            self._log_multitask_initial_state()

            while self.completed_waves < total_waves:
                self._validate_wave_lockstep()
                next_wave = self.completed_waves + 1
                logger.info("Starting synchronized wave %d", next_wave)

                frozen_transfer_states = self._snapshot_task_transfer_states()
                requests: Dict[str, TaskIterationRequest] = {}
                for task_state in self.tasks:
                    requests[task_state.task_name] = self._prepare_task_iteration_request(
                        task_state,
                        frozen_transfer_states=frozen_transfer_states,
                    )
                    logger.info(
                        "Prepared wave %d task '%s' with foreign sources %s",
                        next_wave,
                        task_state.task_name,
                        [
                            source.get("source_task")
                            for source in requests[task_state.task_name].foreign_inspirations
                            if source.get("source_task")
                        ],
                    )

                futures = {
                    task_state.task_name: self._task_workers[task_state.task_name].submit_iteration(
                        requests[task_state.task_name]
                    )
                    for task_state in self.tasks
                }
                wrapped_futures = {
                    task_name: asyncio.wrap_future(future)
                    for task_name, future in futures.items()
                }
                worker_results = await asyncio.gather(
                    *wrapped_futures.values(),
                    return_exceptions=True,
                )

                exceptions = [
                    result for result in worker_results if isinstance(result, BaseException)
                ]
                if exceptions:
                    first_error = exceptions[0]
                    if isinstance(first_error, BrokenProcessPool):
                        message = (
                            f"Synchronized wave {next_wave} aborted before commit due to a "
                            "broken task worker process pool"
                        )
                    else:
                        message = (
                            f"Synchronized wave {next_wave} aborted before commit due to "
                            f"infrastructure failure: {first_error}"
                        )
                    raise RuntimeError(message) from first_error

                results_by_task = {
                    task_name: result
                    for task_name, result in zip(wrapped_futures.keys(), worker_results)
                }
                committed_results: List[TaskIterationResult] = []
                for task_state in self.tasks:
                    self._scheduler_counts[task_state.task_name] = (
                        self._scheduler_counts.get(task_state.task_name, 0) + 1
                    )
                    committed_results.append(
                        self._commit_parallel_task_result(
                            task_state,
                            results_by_task[task_state.task_name],
                        )
                    )

                self.completed_waves = next_wave
                self.completed_global_iterations = self.completed_waves * len(self.tasks)
                self._validate_wave_lockstep()
                self._log_parallel_wave(
                    wave_index=self.completed_waves,
                    committed_results=committed_results,
                )

                checkpoint_written = False
                if self.completed_waves % self.multitask_config.checkpoint_every_waves == 0:
                    self._save_checkpoint(self.completed_waves)
                    checkpoint_written = True

                logger.info(
                    "Committed synchronized wave %d: %s (checkpoint=%s)",
                    self.completed_waves,
                    ", ".join(
                        [
                            (
                                f"{result.task_name}=success"
                                if result.success
                                else f"{result.task_name}=failure({result.failure_reason})"
                            )
                            for result in committed_results
                        ]
                    ),
                    checkpoint_written,
                )

            if self.completed_waves > 0 and self._last_checkpoint_wave != self.completed_waves:
                self._save_checkpoint(self.completed_waves)

            for task_state in self.tasks:
                self._save_task_best_program(task_state)
                best_program = task_state.database.get_best_program()
                if best_program is not None:
                    self._log_task_best_program_artifact(task_state, best_program)

            wall_clock_time_sec = time.time() - self._run_started_at if self._run_started_at else None
            best_programs = self._write_summary(wall_clock_time_sec=wall_clock_time_sec)
            self.wandb_logger.update_summary(
                {
                    "completed_global_iterations": self.completed_global_iterations,
                    "completed_waves": self.completed_waves,
                    "wall_clock_time_sec": wall_clock_time_sec,
                    "scheduler_counts": dict(self._scheduler_counts),
                    **self._aggregate_best_fitness(),
                }
            )
            summary_path = Path(self.output_dir) / "summary.json"
            self.wandb_logger.log_file_artifact(
                str(summary_path),
                artifact_name="final-summary",
                artifact_type="summary",
                metadata={
                    "completed_global_iterations": self.completed_global_iterations,
                    "completed_waves": self.completed_waves,
                    "wall_clock_time_sec": wall_clock_time_sec,
                },
            )
            return best_programs
        finally:
            self._stop_task_workers()
            self._close_tracers()
            self.wandb_logger.finish()


SequentialRoundRobinMultiTaskOpenEvolve = MultiTaskOpenEvolve


def create_multitask_controller(multitask_config: MultitaskConfig) -> MultiTaskOpenEvolve:
    """Instantiate the multitask controller requested by config.execution_mode."""
    if multitask_config.execution_mode == "parallel_synchronized_waves":
        return ParallelWaveMultiTaskOpenEvolve(multitask_config)
    return MultiTaskOpenEvolve(multitask_config)
