"""Dedicated workers for synchronized-wave multitask execution."""

from __future__ import annotations

import asyncio
import logging
import multiprocessing as mp
import os
import random
import sys
import time
import uuid
from concurrent.futures import Future, ProcessPoolExecutor
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np

from openevolve.config import Config, PromptConfig
from openevolve.database import Program
from openevolve.evaluator import Evaluator
from openevolve.llm.ensemble import LLMEnsemble
from openevolve.prompt.sampler import PromptSampler
from openevolve.utils.code_utils import (
    apply_diff,
    apply_diff_blocks,
    extract_diffs,
    format_diff_summary,
    parse_full_rewrite,
    split_diffs_by_target,
)

logger = logging.getLogger(__name__)

# Per-process state for a dedicated task worker. In processes started by
# DedicatedTaskWorker, ``_worker_init`` (the executor initializer) sets the config
# and evaluation file and resets the remaining components to None;
# ``_lazy_init_worker_components`` then builds those components on first use.
_worker_config: Optional[Config] = None  # task config rebuilt from the serialized dict
_worker_evaluation_file: Optional[str] = None  # path handed to the Evaluator
_worker_prompt_sampler: Optional[PromptSampler] = None  # builds generation prompts
_worker_evaluator_prompt_sampler: Optional[PromptSampler] = None  # uses evaluator templates
_worker_llm_ensemble: Optional[LLMEnsemble] = None  # generates candidate programs
_worker_evaluator_llm_ensemble: Optional[LLMEnsemble] = None  # LLM passed to the Evaluator
_worker_evaluator: Optional[Evaluator] = None  # scores programs; built with database=None


class _FixedTaskNameFilter(logging.Filter):
    """Stamp a fixed task label onto worker-process log records."""

    def __init__(self, task_name: str):
        super().__init__()
        self._task_name = task_name

    def filter(self, record: logging.LogRecord) -> bool:
        if not getattr(record, "task_name", None):
            record.task_name = self._task_name
        return True


def _format_worker_exception(exc: BaseException) -> str:
    """Convert worker exceptions into plain strings safe to send across processes."""
    message = str(exc).strip()
    if message:
        return f"{type(exc).__name__}: {message}"
    return type(exc).__name__


def _safe_capture_rng_state(fallback: WorkerRngState) -> WorkerRngState:
    """Capture worker RNG state when possible, otherwise preserve the caller state."""
    try:
        return _capture_rng_state()
    except Exception:
        return fallback


@dataclass
class WorkerRngState:
    """Authoritative task RNG state transferred between main and worker.

    A snapshot of every random source the worker consumes for a task. Each
    request carries one, the worker restores it before doing any work, and
    the worker captures a new snapshot into its result afterwards.
    Consecutive requests for a task therefore continue one random stream.
    """

    # Value of ``random.getstate()`` (Python's global ``random`` module).
    python_random_state: Any
    # Value of ``np.random.get_state()`` (NumPy's global legacy RNG).
    numpy_random_state: Any
    # ``random_state.getstate()`` of the generation LLM ensemble.
    llm_random_state: Any
    # ``random_state.getstate()`` of the evaluator LLM ensemble.
    evaluator_llm_random_state: Any


@dataclass
class InitialProgramEvaluationRequest:
    """Request to evaluate a task's initial program inside its dedicated worker."""

    # Id under which the evaluator runs the program and records its artifacts.
    program_id: str
    # Source code of the initial program.
    code: str
    # Task RNG state restored before evaluation.
    rng_state: WorkerRngState


@dataclass
class InitialProgramEvaluationResult:
    """Serializable result for initial-program evaluation."""

    # Metrics returned by the worker's ``Evaluator.evaluate_program``.
    metrics: Dict[str, Any]
    # Pending artifacts the evaluator recorded for the request's program id.
    artifacts: Dict[str, Any]
    # RNG state after evaluation (the request's state if capturing failed).
    rng_state: WorkerRngState


@dataclass
class TaskIterationRequest:
    """Serializable work item for one task attempt in one synchronized wave.

    Consumed by ``run_task_iteration`` in the task's dedicated worker. Program
    and prompt context are passed as plain dicts and lists so the request can
    be sent to the spawned process.
    """

    # Task this attempt belongs to; echoed into the result.
    task_name: str
    # Task-local iteration number; used as the prompt's evolution round and as
    # the child's ``iteration_found``.
    local_iteration: int
    # Island the resulting child is meant for; echoed into the result.
    target_island: int
    # Payload rebuilt into the parent ``Program`` via ``Program.from_dict``.
    parent_program: Dict[str, Any]
    # Prompt context forwarded to ``PromptSampler.build_prompt``.
    inspirations: List[Dict[str, Any]]
    previous_programs: List[Dict[str, Any]]
    top_programs: List[Dict[str, Any]]
    # Parent artifacts, passed to the prompt builder as ``program_artifacts``.
    parent_artifacts: Dict[str, Any]
    # Inspirations from other tasks; entries with a truthy ``"source_task"`` value
    # are reported back as ``foreign_inspiration_sources``.
    foreign_inspirations: List[Dict[str, Any]]
    # Feature dimension names forwarded to the prompt builder.
    feature_dimensions: List[str]
    # Task RNG state restored before any work is done.
    rng_state: WorkerRngState
    # Transfer bookkeeping; the worker does not use these, only echoes them.
    foreign_transfer_trigger_reason: Optional[str] = None
    chosen_transfer_arm: Optional[str] = None
    # Keyword arguments for ``PromptConfig`` overriding the task's base prompt
    # config for this attempt; None means "use the base config".
    effective_prompt_config: Optional[Dict[str, Any]] = None


@dataclass
class TaskIterationWorkerResult:
    """Serializable worker result for a synchronized-wave task attempt.

    On a successful attempt (``success=True``), ``child_program_dict``,
    ``parent_id`` and ``artifacts`` are filled in. An unsuccessful attempt
    carries ``failure_reason`` instead.
    """

    # Identity of the attempt, echoed from the request.
    task_name: str
    local_iteration: int
    # True only when a child program was generated and evaluated.
    success: bool
    target_island: int
    # RNG state after the attempt (falls back to a caller-supplied state if
    # capturing fails).
    rng_state: WorkerRngState
    # ``source_task`` values of the foreign inspirations included in the request.
    foreign_inspiration_sources: List[str]
    # ``Program.to_dict()`` of the evaluated child; set on success only.
    child_program_dict: Optional[Dict[str, Any]] = None
    # Id of the parent program; set on success only.
    parent_id: Optional[str] = None
    # Human-readable reason for an unsuccessful attempt.
    failure_reason: Optional[str] = None
    # Wall-clock durations in seconds (``time.time()`` deltas); None when the
    # corresponding phase was not reached or not recorded.
    generation_time_sec: Optional[float] = None
    evaluation_time_sec: Optional[float] = None
    iteration_time_sec: Optional[float] = None
    # Prompt sent to the LLM (``"system"`` / ``"user"`` entries), if one was built.
    prompt: Optional[Dict[str, str]] = None
    # Text returned by the generation LLM ensemble, if any.
    llm_response: Optional[str] = None
    # Evaluator artifacts for the child; None on unsuccessful attempts.
    artifacts: Optional[Dict[str, Any]] = None
    # Transfer bookkeeping echoed from the request.
    foreign_transfer_trigger_reason: Optional[str] = None
    chosen_transfer_arm: Optional[str] = None


def serialize_task_config(config: Config) -> Dict[str, Any]:
    """Return ``config.to_dict()`` computed without the live novelty LLM.

    ``config.database.novelty_llm`` holds a live object that must not be shipped
    to a worker. It is detached for the duration of ``to_dict()`` and put back
    afterwards, even when serialization raises, so the caller's config ends up
    unchanged.

    NOTE: the config is mutated temporarily, so do not call this concurrently
    with other users of the same config object.
    """
    database_config = config.database
    detached_llm = database_config.novelty_llm
    try:
        database_config.novelty_llm = None
        return config.to_dict()
    finally:
        database_config.novelty_llm = detached_llm


def _worker_init(
    config_dict: Dict[str, Any],
    evaluation_file: str,
    parent_env: Dict[str, str],
    task_env: Dict[str, str],
    task_name: str,
    worker_log_path: Optional[str],
    log_level_name: str,
) -> None:
    """Initialize one dedicated task worker process.

    Used as the ``ProcessPoolExecutor`` initializer by ``DedicatedTaskWorker``.
    It performs three steps:

    1. Replaces (and closes) all root-logger handlers. The replacement is either
       a UTF-8 file handler at ``worker_log_path`` that stamps records with
       ``task_name``, or, when no path is given, a ``NullHandler`` with the root
       level raised to CRITICAL.
    2. Applies ``parent_env`` and then ``task_env`` to ``os.environ``, so task
       values win over parent values.
    3. Rebuilds the task ``Config`` and resets every lazily built component, so
       ``_lazy_init_worker_components`` creates them on first use.

    Args:
        config_dict: Output of ``serialize_task_config`` for this task.
        evaluation_file: Path handed to the worker's ``Evaluator``.
        parent_env: Environment snapshot taken in the launching process.
        task_env: Task-specific environment overrides.
        task_name: Label written into log records as ``task_name``.
        worker_log_path: Log file path, or a falsy value to silence worker logging.
        log_level_name: Name of a ``logging`` level such as ``"INFO"``. Unknown
            names fall back to INFO.
    """
    root_logger = logging.getLogger()
    for handler in list(root_logger.handlers):
        root_logger.removeHandler(handler)
        # Release whatever the discarded handler holds (e.g. an open log file).
        handler.close()

    if worker_log_path:
        log_dir = os.path.dirname(worker_log_path)
        # A bare filename has no directory part, and os.makedirs("") would raise.
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        # Explicit UTF-8: logged text (e.g. LLM output) may contain characters the
        # platform default encoding cannot represent.
        file_handler = logging.FileHandler(worker_log_path, encoding="utf-8")
        file_handler.addFilter(_FixedTaskNameFilter(task_name))
        file_handler.setFormatter(
            logging.Formatter(
                "%(asctime)s - %(processName)s[%(process)d] - %(name)s - "
                "%(levelname)s - [%(task_name)s] %(message)s"
            )
        )
        root_logger.addHandler(file_handler)
        level = getattr(logging, str(log_level_name).upper(), None)
        if not isinstance(level, int):
            # Unknown names, or ``logging`` attributes that are not levels, fall back to INFO.
            level = logging.INFO
        root_logger.setLevel(level)
    else:
        root_logger.addHandler(logging.NullHandler())
        root_logger.setLevel(logging.CRITICAL)

    os.environ.update(parent_env)
    os.environ.update(task_env)

    global _worker_config
    global _worker_evaluation_file
    global _worker_prompt_sampler
    global _worker_evaluator_prompt_sampler
    global _worker_llm_ensemble
    global _worker_evaluator_llm_ensemble
    global _worker_evaluator

    _worker_config = Config.from_dict(config_dict)
    _worker_evaluation_file = evaluation_file
    _worker_prompt_sampler = None
    _worker_evaluator_prompt_sampler = None
    _worker_llm_ensemble = None
    _worker_evaluator_llm_ensemble = None
    _worker_evaluator = None


def _lazy_init_worker_components() -> None:
    """Create any worker component that has not been built yet.

    Each component is built only when missing, in this order:
    - the generation prompt sampler;
    - the evaluator prompt sampler, switched to the ``"evaluator_system_message"``
      templates;
    - the generation LLM ensemble and the evaluator LLM ensemble;
    - the ``Evaluator``, wired to the evaluator ensemble and sampler, with no
      database.

    Raises:
        RuntimeError: if ``_worker_init`` has not set the config and the
            evaluation file in this process.
    """
    global _worker_prompt_sampler, _worker_evaluator_prompt_sampler
    global _worker_llm_ensemble, _worker_evaluator_llm_ensemble, _worker_evaluator

    cfg = _worker_config
    eval_file = _worker_evaluation_file
    if cfg is None or eval_file is None:
        raise RuntimeError("Dedicated multitask worker used before initialization")

    prompt_cfg = cfg.prompt
    if _worker_prompt_sampler is None:
        _worker_prompt_sampler = PromptSampler(prompt_cfg)

    if _worker_evaluator_prompt_sampler is None:
        # Publish the sampler before switching templates, matching the order in
        # which the module state has always been updated.
        _worker_evaluator_prompt_sampler = PromptSampler(prompt_cfg)
        _worker_evaluator_prompt_sampler.set_templates("evaluator_system_message")

    if _worker_llm_ensemble is None:
        _worker_llm_ensemble = LLMEnsemble(cfg.llm.models)

    if _worker_evaluator_llm_ensemble is None:
        _worker_evaluator_llm_ensemble = LLMEnsemble(cfg.llm.evaluator_models)

    if _worker_evaluator is None:
        _worker_evaluator = Evaluator(
            cfg.evaluator,
            eval_file,
            _worker_evaluator_llm_ensemble,
            _worker_evaluator_prompt_sampler,
            database=None,
            suffix=cfg.file_suffix,
        )


def _restore_rng_state(rng_state: WorkerRngState) -> None:
    """Load ``rng_state`` into every random source used by this worker.

    Components are built first via ``_lazy_init_worker_components``, so any
    effect their construction has on global RNG state is overwritten by the
    restore. Python's ``random`` module and NumPy's global RNG are restored
    before the two LLM ensembles' ``random_state`` objects.

    Raises:
        RuntimeError: if either LLM ensemble is still missing after initialization.
    """
    _lazy_init_worker_components()

    random.setstate(rng_state.python_random_state)
    np.random.set_state(rng_state.numpy_random_state)

    generator_llm = _worker_llm_ensemble
    evaluator_llm = _worker_evaluator_llm_ensemble
    if generator_llm is None or evaluator_llm is None:
        raise RuntimeError("Dedicated multitask worker LLMs are not initialized")
    generator_llm.random_state.setstate(rng_state.llm_random_state)
    evaluator_llm.random_state.setstate(rng_state.evaluator_llm_random_state)


def _capture_rng_state() -> WorkerRngState:
    """Snapshot every random source used by this worker.

    Returns:
        A ``WorkerRngState`` holding Python's ``random`` state, NumPy's global
        RNG state, and both LLM ensembles' ``random_state`` states.

    Raises:
        RuntimeError: if either LLM ensemble has not been created.
    """
    generator_llm, evaluator_llm = _worker_llm_ensemble, _worker_evaluator_llm_ensemble
    if generator_llm is None or evaluator_llm is None:
        raise RuntimeError("Dedicated multitask worker LLMs are not initialized")

    python_state = random.getstate()
    numpy_state = np.random.get_state()
    return WorkerRngState(
        python_random_state=python_state,
        numpy_random_state=numpy_state,
        llm_random_state=generator_llm.random_state.getstate(),
        evaluator_llm_random_state=evaluator_llm.random_state.getstate(),
    )


def _get_effective_prompt_config(request: TaskIterationRequest) -> PromptConfig:
    """Return the prompt config that applies to ``request``.

    If the request carries no override, this is the worker's base prompt
    config (the same object). Otherwise it is a new ``PromptConfig`` built from
    the override's keyword arguments.

    Raises:
        RuntimeError: if the worker has not been initialized.
    """
    base = _worker_config
    if base is None:
        raise RuntimeError("Dedicated multitask worker used before initialization")
    overrides = request.effective_prompt_config
    return base.prompt if overrides is None else PromptConfig(**overrides)


def _get_iteration_prompt_sampler(prompt_config: PromptConfig) -> PromptSampler:
    """Return a prompt sampler configured by ``prompt_config``.

    The cached worker sampler, built from the task's base prompt config, is
    reused when ``prompt_config`` is that base config or compares equal to it.
    Otherwise a new sampler is built, so every field of the effective config
    takes effect, not only ``num_top_programs`` / ``num_diverse_programs``.

    This runs after ``_restore_rng_state`` has loaded the task's RNG state.
    The Python and NumPy global RNG states are therefore saved before building
    a new sampler and restored afterwards, keeping the task's random stream
    independent of sampler construction.

    Raises:
        RuntimeError: if the worker config or the cached sampler is missing.
    """
    if _worker_prompt_sampler is None or _worker_config is None:
        raise RuntimeError("Dedicated multitask worker prompt sampler is not initialized")

    base_config = _worker_config.prompt
    if prompt_config is base_config or prompt_config == base_config:
        return _worker_prompt_sampler

    python_state = random.getstate()
    numpy_state = np.random.get_state()
    try:
        return PromptSampler(prompt_config)
    finally:
        # NOTE(review): guards against PromptSampler construction touching the
        # global RNGs (e.g. reseeding); confirm against PromptSampler.__init__.
        random.setstate(python_state)
        np.random.set_state(numpy_state)


def run_initial_program_evaluation(
    request: InitialProgramEvaluationRequest,
) -> InitialProgramEvaluationResult:
    """Evaluate a task's initial program inside its dedicated worker.

    Restores the task RNG state and runs the evaluator on ``request.code``
    under ``request.program_id``. Returns the metrics, the artifacts the
    evaluator recorded for that id, and the RNG state captured afterwards.

    Raises:
        RuntimeError: wrapping any failure as a plain message. The original
            exception is deliberately not chained, so only a string crosses the
            process boundary.
    """
    try:
        _restore_rng_state(request.rng_state)

        evaluator = _worker_evaluator
        if evaluator is None:
            raise RuntimeError("Dedicated multitask worker evaluator is not initialized")

        program_id = request.program_id
        metrics = asyncio.run(evaluator.evaluate_program(request.code, program_id))
        result = InitialProgramEvaluationResult(
            metrics=metrics,
            artifacts=evaluator.get_pending_artifacts(program_id),
            rng_state=_safe_capture_rng_state(request.rng_state),
        )
    except Exception as exc:
        message = _format_worker_exception(exc)
        raise RuntimeError(f"initial program evaluation failed: {message}") from None
    return result


def _normal_failure(
    request: TaskIterationRequest,
    *,
    reason: str,
    iteration_start: float,
    rng_state: Optional[WorkerRngState] = None,
    generation_time_sec: Optional[float] = None,
    evaluation_time_sec: Optional[float] = None,
    prompt: Optional[Dict[str, str]] = None,
    llm_response: Optional[str] = None,
) -> TaskIterationWorkerResult:
    """Package an expected (non-exceptional) iteration failure as a result.

    The RNG state is captured now. If capturing fails, ``rng_state`` (or else
    the request's own state) is used instead. The elapsed iteration time is
    measured from ``iteration_start``. Transfer bookkeeping and the foreign
    inspiration sources are echoed from ``request``.
    """
    captured_state = _safe_capture_rng_state(rng_state or request.rng_state)
    sources = [
        entry["source_task"]
        for entry in request.foreign_inspirations
        if entry.get("source_task")
    ]
    elapsed = time.time() - iteration_start

    return TaskIterationWorkerResult(
        task_name=request.task_name,
        local_iteration=request.local_iteration,
        success=False,
        target_island=request.target_island,
        rng_state=captured_state,
        foreign_inspiration_sources=sources,
        failure_reason=reason,
        generation_time_sec=generation_time_sec,
        evaluation_time_sec=evaluation_time_sec,
        iteration_time_sec=elapsed,
        prompt=prompt,
        llm_response=llm_response,
        artifacts=None,
        foreign_transfer_trigger_reason=request.foreign_transfer_trigger_reason,
        chosen_transfer_arm=request.chosen_transfer_arm,
    )


def run_task_iteration(request: TaskIterationRequest) -> TaskIterationWorkerResult:
    """Execute one attempted task iteration inside a dedicated worker.

    Steps:
    1. Restore the task RNG state.
    2. Build a prompt from the parent program and the context in ``request``.
    3. Query the generation LLM ensemble.
    4. Turn the response into child code, by diff application or full rewrite.
    5. Evaluate the child and package it with the RNG state captured afterwards.

    Every prompt-shaping decision (changes-description mode, diff-summary
    limits) comes from the request's *effective* prompt config. The prompt that
    was built and the way the response is parsed therefore always agree.

    Returns:
        A successful result carrying the child program. Otherwise an
        unsuccessful result (via ``_normal_failure``) when the LLM returns
        nothing, the response cannot be turned into a candidate, the changes
        description was not updated, or the generated code is too long.

    Raises:
        RuntimeError: for any other failure. The message names the stage that
            failed, and the original exception is not chained (``from None``),
            so only a plain message crosses the process boundary.
    """
    iteration_start = time.time()
    generation_time: Optional[float] = None
    evaluation_time: Optional[float] = None
    prompt: Optional[Dict[str, str]] = None
    llm_response: Optional[str] = None
    stage = "worker initialization"

    def _reject(reason: str) -> TaskIterationWorkerResult:
        # Late-bound closure: reports whatever prompt, response and generation
        # timing exist at the moment the candidate is rejected. All rejections
        # happen before evaluation, so no evaluation time is reported.
        return _normal_failure(
            request,
            reason=reason,
            iteration_start=iteration_start,
            generation_time_sec=generation_time,
            prompt=prompt,
            llm_response=llm_response,
        )

    try:
        _restore_rng_state(request.rng_state)
        stage = "component validation"

        if (
            _worker_config is None
            or _worker_prompt_sampler is None
            or _worker_llm_ensemble is None
            or _worker_evaluator is None
        ):
            raise RuntimeError("Dedicated multitask worker is not initialized")

        parent = Program.from_dict(request.parent_program)
        effective_prompt_config = _get_effective_prompt_config(request)
        prompt_sampler = _get_iteration_prompt_sampler(effective_prompt_config)
        foreign_inspiration_sources = [
            source["source_task"]
            for source in request.foreign_inspirations
            if source.get("source_task")
        ]

        # One flag, taken from the effective config, drives both the prompt and
        # the diff parsing below. Mixing in the base config here could leave
        # parent_changes_desc None while description diffs are expected.
        track_changes_description = effective_prompt_config.programs_as_changes_description
        if track_changes_description:
            parent_changes_desc = (
                parent.changes_description or effective_prompt_config.initial_changes_description
            )
            child_changes_desc = parent_changes_desc
        else:
            parent_changes_desc = None
            child_changes_desc = ""

        stage = "prompt construction"
        prompt = prompt_sampler.build_prompt(
            current_program=parent.code,
            parent_program=parent.code,
            program_metrics=parent.metrics,
            previous_programs=request.previous_programs,
            top_programs=request.top_programs,
            inspirations=request.inspirations,
            foreign_inspirations=request.foreign_inspirations,
            language=_worker_config.language,
            evolution_round=request.local_iteration,
            diff_based_evolution=_worker_config.diff_based_evolution,
            program_artifacts=request.parent_artifacts,
            feature_dimensions=request.feature_dimensions,
            current_changes_description=parent_changes_desc,
        )

        stage = "LLM generation"
        generation_start = time.time()
        llm_response = asyncio.run(
            _worker_llm_ensemble.generate_with_context(
                system_message=prompt["system"],
                messages=[{"role": "user", "content": prompt["user"]}],
            )
        )
        generation_time = time.time() - generation_start

        if llm_response is None:
            return _reject("LLM returned no response")

        stage = "candidate parsing"
        try:
            if _worker_config.diff_based_evolution:
                diff_blocks = extract_diffs(llm_response, _worker_config.diff_pattern)
                if not diff_blocks:
                    return _reject("no valid diffs found in response")

                if track_changes_description:
                    code_blocks, desc_blocks, _unmatched = split_diffs_by_target(
                        diff_blocks,
                        code_text=parent.code,
                        changes_description_text=parent_changes_desc,
                    )
                    child_code, _ = apply_diff_blocks(parent.code, code_blocks)
                    child_changes_desc, desc_applied = apply_diff_blocks(
                        parent_changes_desc, desc_blocks
                    )
                    if (
                        desc_applied == 0
                        or not child_changes_desc.strip()
                        or child_changes_desc.strip() == parent_changes_desc.strip()
                    ):
                        return _reject("changes_description was not updated or is empty")
                    # Summarize only the code edits, not the description edits.
                    summary_blocks = code_blocks
                else:
                    child_code = apply_diff(parent.code, llm_response, _worker_config.diff_pattern)
                    summary_blocks = diff_blocks

                changes_summary = format_diff_summary(
                    summary_blocks,
                    max_line_len=effective_prompt_config.diff_summary_max_line_len,
                    max_lines=effective_prompt_config.diff_summary_max_lines,
                )
            else:
                new_code = parse_full_rewrite(llm_response, _worker_config.language)
                if not new_code:
                    return _reject("no valid rewritten program found in response")
                child_code = new_code
                changes_summary = "Full rewrite"
        except Exception as exc:
            # A malformed LLM response is an expected outcome, not a worker crash.
            return _reject(_format_worker_exception(exc))

        if len(child_code) > _worker_config.max_code_length:
            return _reject(
                "generated code exceeds max length "
                f"({len(child_code)} > {_worker_config.max_code_length})"
            )

        child_program_id = str(uuid.uuid4())
        stage = "evaluation"
        evaluation_start = time.time()
        child_metrics = asyncio.run(_worker_evaluator.evaluate_program(child_code, child_program_id))
        evaluation_time = time.time() - evaluation_start
        artifacts = _worker_evaluator.get_pending_artifacts(child_program_id)

        metadata = {
            "changes": changes_summary,
            "parent_metrics": parent.metrics,
        }
        if foreign_inspiration_sources:
            metadata["foreign_inspiration_sources"] = foreign_inspiration_sources

        child_program = Program(
            id=child_program_id,
            code=child_code,
            changes_description=child_changes_desc,
            language=_worker_config.language,
            parent_id=parent.id,
            generation=parent.generation + 1,
            metrics=child_metrics,
            iteration_found=request.local_iteration,
            metadata=metadata,
        )

        return TaskIterationWorkerResult(
            task_name=request.task_name,
            local_iteration=request.local_iteration,
            success=True,
            target_island=request.target_island,
            rng_state=_safe_capture_rng_state(request.rng_state),
            foreign_inspiration_sources=foreign_inspiration_sources,
            child_program_dict=child_program.to_dict(),
            parent_id=parent.id,
            generation_time_sec=generation_time,
            evaluation_time_sec=evaluation_time,
            iteration_time_sec=time.time() - iteration_start,
            prompt=prompt,
            llm_response=llm_response,
            artifacts=artifacts,
            foreign_transfer_trigger_reason=request.foreign_transfer_trigger_reason,
            chosen_transfer_arm=request.chosen_transfer_arm,
        )
    except Exception as exc:
        raise RuntimeError(f"{stage} failed: {_format_worker_exception(exc)}") from None


class DedicatedTaskWorker:
    """A persistent, single-process, spawn-based executor reserved for one task.

    Everything the worker process needs is captured at construction time: the
    serialized config, the evaluation file, the environments, and the log
    destination and level. These values are handed to ``_worker_init`` as the
    executor initializer. Call :meth:`start` before submitting work and
    :meth:`stop` to shut the process down.
    """

    def __init__(
        self,
        *,
        task_name: str,
        config: Config,
        evaluation_file: str,
        task_env: Dict[str, str],
        worker_log_path: Optional[str] = None,
    ):
        self._task_name = task_name
        self._evaluation_file = evaluation_file
        self._worker_log_path = worker_log_path
        # Copy both mappings so later mutation of the originals cannot change
        # what the worker is initialized with.
        self._parent_env = dict(os.environ)
        self._task_env = dict(task_env)
        self._config_dict = serialize_task_config(config)
        self._max_tasks_per_child = config.max_tasks_per_child
        self._log_level_name = config.log_level
        self._executor: Optional[ProcessPoolExecutor] = None

    def start(self) -> None:
        """Create the backing one-worker executor; does nothing if it already exists."""
        if self._executor is not None:
            return

        initializer_args = (
            self._config_dict,
            self._evaluation_file,
            self._parent_env,
            self._task_env,
            self._task_name,
            self._worker_log_path,
            self._log_level_name,
        )
        pool_options: Dict[str, Any] = dict(
            max_workers=1,
            initializer=_worker_init,
            initargs=initializer_args,
            mp_context=mp.get_context("spawn"),
        )
        # ProcessPoolExecutor only accepts max_tasks_per_child on Python 3.11+.
        recycle_after = self._max_tasks_per_child
        if recycle_after is not None and sys.version_info >= (3, 11):
            pool_options["max_tasks_per_child"] = recycle_after

        self._executor = ProcessPoolExecutor(**pool_options)

    def stop(self) -> None:
        """Shut the executor down, cancelling queued work and waiting for it to finish."""
        executor = self._executor
        if executor is None:
            return
        executor.shutdown(wait=True, cancel_futures=True)
        self._executor = None

    def _require_executor(self) -> ProcessPoolExecutor:
        """Return the running executor, or raise if :meth:`start` was never called."""
        executor = self._executor
        if executor is None:
            raise RuntimeError("Dedicated multitask worker was not started")
        return executor

    def submit_initial_program(
        self, request: InitialProgramEvaluationRequest
    ) -> Future[InitialProgramEvaluationResult]:
        """Queue evaluation of the task's initial program on the worker."""
        return self._require_executor().submit(run_initial_program_evaluation, request)

    def submit_iteration(
        self, request: TaskIterationRequest
    ) -> Future[TaskIterationWorkerResult]:
        """Queue one task iteration attempt on the worker."""
        return self._require_executor().submit(run_task_iteration, request)
