"""Configuration loading for OpenEvolve multitask runs."""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional

import yaml

from openevolve.config import Config, WandbConfig, deep_merge_dicts, load_config


def _resolve_path(path_value: str, base_dir: Path) -> str:
    """Return ``path_value`` as a string, anchoring relative paths at ``base_dir``.

    Args:
        path_value: Path taken from the configuration; absolute or relative.
        base_dir: Directory that relative paths are interpreted against.

    Returns:
        An absolute ``path_value`` converted back to ``str`` without further
        resolution; otherwise ``base_dir / path_value`` after ``Path.resolve()``.
    """
    candidate = Path(path_value)
    if not candidate.is_absolute():
        # Only relative inputs are joined and resolved; absolute ones pass through.
        candidate = (base_dir / candidate).resolve()
    return str(candidate)


@dataclass
class SchedulerConfig:
    """Configuration for multitask scheduling.

    ``MultitaskConfig.from_dict`` builds this from the optional ``scheduler``
    mapping of the multitask configuration.
    """

    # Scheduling strategy name. MultitaskConfig.validate() currently rejects
    # every value other than "round_robin".
    strategy: str = "round_robin"
    # Not read anywhere in this module; presumably per-task weights keyed by
    # task name for future strategies -- TODO confirm against the scheduler.
    weights: Dict[str, float] = field(default_factory=dict)


@dataclass
class ForeignInspirationsPromptOverridesConfig:
    """Optional prompt-budget overrides applied only on transfer iterations.

    Every field defaults to ``None``. ``MultitaskConfig.validate`` skips ``None``
    values and requires any other value to be a non-negative ``int`` (``bool``
    is rejected explicitly).
    """

    # NOTE(review): the names suggest these map onto the prompt budget settings
    # of the base config -- confirm against the code that consumes them.
    num_top_programs: Optional[int] = None
    num_diverse_programs: Optional[int] = None
    num_local_inspirations: Optional[int] = None


@dataclass
class ForeignInspirationsConfig:
    """Configuration for prompt-only cross-task inspirations.

    The range constraints noted on the fields below are the ones enforced by
    ``MultitaskConfig.validate``; this class performs no validation itself.
    """

    # Not inspected by this module's validation; presumably the master switch
    # for the feature -- confirm against the consumer.
    enabled: bool = True
    # Must be one of "periodic", "stagnation", or "online_bandit".
    trigger_mode: str = "periodic"
    # Must be > 0; only checked when trigger_mode == "periodic".
    every_n_task_iterations: int = 4
    # Must be >= 0.
    warmup_task_iterations: int = 0
    # Must be >= 0, and exactly 1 when trigger_mode == "online_bandit".
    max_related_tasks: int = 1
    # Must be > 0.
    top_programs_per_related_task: int = 1
    # The three include_* flags are not validated in this module.
    include_optional_relation_text: bool = False
    include_scores: bool = True
    include_code: bool = True
    # The next three are only checked when trigger_mode is not "periodic":
    # stagnation_patience >= 1, transfer_cooldown >= 0,
    # min_best_fitness_improvement >= 0.
    stagnation_patience: int = 6
    transfer_cooldown: int = 4
    min_best_fitness_improvement: float = 1.0e-4
    # The next two are only checked when trigger_mode == "online_bandit":
    # min_pulls_per_arm >= 1 and 0 < bandit_decay <= 1.
    min_pulls_per_arm: int = 2
    bandit_decay: float = 1.0
    # Must be "sparse" or "rich".
    reward_mode: str = "sparse"
    # Must be >= 1.
    reward_window: int = 5
    # Must be >= 0.
    reward_margin: float = 0.0
    # Optional per-transfer prompt budget overrides; None means no overrides.
    prompt_overrides: Optional[ForeignInspirationsPromptOverridesConfig] = None

    @property
    def min_improvement(self) -> float:
        """Backward-compatible alias for older configs and call sites.

        Reads ``min_best_fitness_improvement``. This is a property rather than
        a dataclass field, so it is not accepted as a constructor keyword;
        ``MultitaskConfig.from_dict`` translates the legacy ``min_improvement``
        key before constructing this class.
        """
        return self.min_best_fitness_improvement

    @min_improvement.setter
    def min_improvement(self, value: float) -> None:
        # Writes through to the canonical field so both names stay in sync.
        self.min_best_fitness_improvement = value


@dataclass
class RelatedTaskConfig:
    """Relationship metadata for a task pair.

    Instances live in ``TaskConfig.related_tasks`` of the task that lists the
    relationship.
    """

    # Name of the other task. MultitaskConfig.validate() requires it to match a
    # configured task name and to differ from the task that lists it.
    source_task: str
    # Not read in this module; presumably toggles this relationship -- confirm.
    enabled: bool = True
    # Not read in this module; presumably free-form text for prompts -- confirm.
    prompt_context: Optional[str] = None


@dataclass
class TaskConfig:
    """Per-task configuration within a multitask run.

    ``MultitaskConfig.from_dict`` builds one instance per entry of the ``tasks``
    list and rejects duplicate task names.
    """

    # Task identifier; also the fallback for output_subdir in from_dict.
    name: str
    # from_dict resolves both paths relative to the config file's directory
    # (or the current working directory when no config path is given).
    initial_program: str
    evaluation_file: str
    # from_dict substitutes the task name when this is missing or falsy;
    # MultitaskConfig.validate() requires it to be unique across tasks.
    output_subdir: Optional[str] = None
    # from_dict converts every value to str.
    env: Dict[str, str] = field(default_factory=dict)
    # Stored as given; presumably passed as `overrides` to derive_task_config
    # by the runner -- confirm against the caller.
    config_overrides: Dict[str, Any] = field(default_factory=dict)
    # Tasks this task is related to; see RelatedTaskConfig.
    related_tasks: List[RelatedTaskConfig] = field(default_factory=list)


@dataclass
class MultitaskConfig:
    """Root configuration for a multitask OpenEvolve run.

    Instances are normally created through :meth:`from_yaml` or :meth:`from_dict`,
    which resolve relative paths against the configuration file's directory and
    call :meth:`validate` before returning.
    """

    # Resolved to an absolute path by from_dict.
    output_dir: str = "results/multitask_run"
    # Either "sequential_round_robin" or "parallel_synchronized_waves".
    execution_mode: str = "sequential_round_robin"
    # Validated (> 0) only when execution_mode == "sequential_round_robin".
    max_global_iterations: int = 100
    checkpoint_interval: int = 10
    # Required (> 0) only when execution_mode == "parallel_synchronized_waves".
    max_waves: Optional[int] = None
    checkpoint_every_waves: Optional[int] = None
    scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
    foreign_inspirations: ForeignInspirationsConfig = field(
        default_factory=ForeignInspirationsConfig
    )
    wandb: Optional[WandbConfig] = None
    # Path of the shared base OpenEvolve config; resolved by from_dict when set.
    base_config_path: Optional[str] = None
    # Inline overrides stored as given by from_dict.
    base_config: Dict[str, Any] = field(default_factory=dict)
    tasks: List[TaskConfig] = field(default_factory=list)
    # Absolute path of the file this config was loaded from, if any.
    config_path: Optional[str] = None
    # Directory that relative paths were resolved against.
    config_dir: Optional[str] = None

    @classmethod
    def from_yaml(cls, path: str | Path) -> "MultitaskConfig":
        """Load and validate a multitask configuration from a YAML file.

        Args:
            path: Location of the YAML file. An empty document is treated as an
                empty mapping.

        Returns:
            The validated configuration.

        Raises:
            OSError: If the file cannot be opened.
            ValueError: If the content is not a valid multitask configuration.
        """
        config_path = Path(path).resolve()
        # Pin the encoding so loading does not depend on the platform default.
        with open(config_path, "r", encoding="utf-8") as handle:
            raw_config = yaml.safe_load(handle) or {}
        return cls.from_dict(raw_config, config_path=config_path)

    @classmethod
    def from_dict(
        cls, config_dict: Dict[str, Any], config_path: Optional[str | Path] = None
    ) -> "MultitaskConfig":
        """Build and validate a configuration from an already-parsed mapping.

        Settings are read from ``config_dict["multitask"]`` when that key exists,
        otherwise from ``config_dict`` itself. Scalar settings that are missing or
        explicitly null fall back to their defaults.

        Args:
            config_dict: Parsed configuration mapping.
            config_path: Path of the file the mapping came from. Relative paths in
                the configuration are resolved against its parent directory, or
                against the current working directory when omitted.

        Returns:
            The validated configuration.

        Raises:
            ValueError: If the structure is not a mapping where one is required,
                legacy and current keys conflict, task names are duplicated, or
                :meth:`validate` rejects the result.
            KeyError: If a task entry lacks ``name``, ``initial_program`` or
                ``evaluation_file``.
            TypeError: If a section carries a key its dataclass does not accept.
        """
        if not isinstance(config_dict, dict):
            raise ValueError("Multitask configuration must be a mapping")
        root = config_dict.get("multitask", config_dict)
        if not isinstance(root, dict):
            raise ValueError("Multitask configuration must be a mapping")

        base_dir = Path(config_path).resolve().parent if config_path else Path.cwd()

        def setting(key: str, default: Any) -> Any:
            # A key written as ``key:`` in YAML parses to None; treat it as unset
            # instead of letting None leak into path handling or comparisons.
            value = root.get(key)
            return default if value is None else value

        scheduler = SchedulerConfig(**(root.get("scheduler") or {}))
        foreign_inspirations = cls._parse_foreign_inspirations(
            root.get("foreign_inspirations")
        )
        tasks = cls._parse_tasks(root.get("tasks"), base_dir)

        base_config_path = root.get("base_config_path")
        if base_config_path:
            base_config_path = _resolve_path(base_config_path, base_dir)

        # ``wandb: null`` is handled like an absent section; an empty mapping
        # still yields a WandbConfig built from its defaults.
        wandb_dict = root.get("wandb")

        config = cls(
            output_dir=_resolve_path(setting("output_dir", "results/multitask_run"), base_dir),
            execution_mode=setting("execution_mode", "sequential_round_robin"),
            max_global_iterations=setting("max_global_iterations", 100),
            checkpoint_interval=setting("checkpoint_interval", 10),
            max_waves=root.get("max_waves"),
            checkpoint_every_waves=root.get("checkpoint_every_waves"),
            scheduler=scheduler,
            foreign_inspirations=foreign_inspirations,
            wandb=None if wandb_dict is None else WandbConfig(**wandb_dict),
            base_config_path=base_config_path,
            base_config=root.get("base_config") or {},
            tasks=tasks,
            config_path=str(Path(config_path).resolve()) if config_path else None,
            config_dir=str(base_dir),
        )
        config.validate()
        return config

    @staticmethod
    def _parse_foreign_inspirations(
        raw: Optional[Dict[str, Any]],
    ) -> ForeignInspirationsConfig:
        """Translate legacy keys and build the foreign-inspirations section."""
        settings = dict(raw or {})

        # Legacy boolean ``require_stagnation`` maps onto ``trigger_mode``.
        legacy_require_stagnation = settings.pop("require_stagnation", None)
        if legacy_require_stagnation is not None:
            explicit_trigger_mode = settings.get("trigger_mode")
            if legacy_require_stagnation:
                if explicit_trigger_mode and explicit_trigger_mode != "stagnation":
                    # Report the mode that was actually configured, which need
                    # not be 'periodic'.
                    raise ValueError(
                        "multitask.foreign_inspirations.require_stagnation=true conflicts with "
                        f"trigger_mode={explicit_trigger_mode!r}"
                    )
                settings["trigger_mode"] = "stagnation"
            elif explicit_trigger_mode is None:
                settings["trigger_mode"] = "periodic"

        # Legacy ``min_improvement`` maps onto ``min_best_fitness_improvement``.
        legacy_min_improvement = settings.pop("min_improvement", None)
        if legacy_min_improvement is not None:
            explicit_min_improvement = settings.get("min_best_fitness_improvement")
            if (
                explicit_min_improvement is not None
                and explicit_min_improvement != legacy_min_improvement
            ):
                raise ValueError(
                    "multitask.foreign_inspirations.min_improvement conflicts "
                    "with min_best_fitness_improvement"
                )
            settings["min_best_fitness_improvement"] = legacy_min_improvement

        prompt_overrides_dict = settings.get("prompt_overrides")
        if prompt_overrides_dict is not None:
            if not isinstance(prompt_overrides_dict, dict):
                raise ValueError(
                    "multitask.foreign_inspirations.prompt_overrides must be a mapping"
                )
            settings["prompt_overrides"] = ForeignInspirationsPromptOverridesConfig(
                **prompt_overrides_dict
            )

        return ForeignInspirationsConfig(**settings)

    @staticmethod
    def _parse_tasks(raw_tasks: Any, base_dir: Path) -> List[TaskConfig]:
        """Build the task list, resolving paths and rejecting duplicate names."""
        tasks: List[TaskConfig] = []
        seen_names: set[str] = set()
        for task_dict in raw_tasks or []:
            if not isinstance(task_dict, dict):
                raise ValueError("Each entry in multitask.tasks must be a mapping")
            related_tasks = [
                RelatedTaskConfig(**related_task)
                for related_task in (task_dict.get("related_tasks") or [])
            ]
            task = TaskConfig(
                name=task_dict["name"],
                initial_program=_resolve_path(task_dict["initial_program"], base_dir),
                evaluation_file=_resolve_path(task_dict["evaluation_file"], base_dir),
                # Each task writes under a directory named after it unless told otherwise.
                output_subdir=task_dict.get("output_subdir") or task_dict["name"],
                env={key: str(value) for key, value in (task_dict.get("env") or {}).items()},
                config_overrides=task_dict.get("config_overrides") or {},
                related_tasks=related_tasks,
            )
            if task.name in seen_names:
                raise ValueError(f"Duplicate multitask task name: {task.name}")
            seen_names.add(task.name)
            tasks.append(task)
        return tasks

    def validate(self) -> None:
        """Validate the multitask configuration.

        Checks run in a fixed order: at least one task, execution mode and
        scheduler settings, foreign-inspiration settings, then per-task output
        directories and related-task references.

        Raises:
            ValueError: On the first violated constraint.
        """
        if not self.tasks:
            raise ValueError("Multitask configuration must define at least one task")

        self._validate_execution()
        self._validate_foreign_inspirations()
        self._validate_tasks()

    def _validate_execution(self) -> None:
        """Check the execution mode, scheduler strategy and iteration budgets."""
        if self.execution_mode not in {
            "sequential_round_robin",
            "parallel_synchronized_waves",
        }:
            raise ValueError(
                "multitask.execution_mode must be one of "
                "'sequential_round_robin' or 'parallel_synchronized_waves'"
            )

        if self.scheduler.strategy != "round_robin":
            raise ValueError(
                "Only scheduler.strategy='round_robin' is supported in multitask mode v1"
            )

        if self.execution_mode == "sequential_round_robin":
            if self.max_global_iterations <= 0:
                raise ValueError("multitask.max_global_iterations must be greater than zero")

            if self.checkpoint_interval <= 0:
                raise ValueError("multitask.checkpoint_interval must be greater than zero")
        else:
            # Wave-based settings have no defaults, so they must be given explicitly.
            if self.max_waves is None or self.max_waves <= 0:
                raise ValueError(
                    "multitask.max_waves must be greater than zero when "
                    "execution_mode='parallel_synchronized_waves'"
                )
            if self.checkpoint_every_waves is None or self.checkpoint_every_waves <= 0:
                raise ValueError(
                    "multitask.checkpoint_every_waves must be greater than zero when "
                    "execution_mode='parallel_synchronized_waves'"
                )

    def _validate_foreign_inspirations(self) -> None:
        """Check the foreign-inspirations section, including mode-specific limits."""
        fi_cfg = self.foreign_inspirations
        if fi_cfg.trigger_mode not in {"periodic", "stagnation", "online_bandit"}:
            raise ValueError(
                "multitask.foreign_inspirations.trigger_mode must be one of "
                "'periodic', 'stagnation', or 'online_bandit'"
            )
        if fi_cfg.prompt_overrides is not None:
            for field_name in (
                "num_top_programs",
                "num_diverse_programs",
                "num_local_inspirations",
            ):
                value = getattr(fi_cfg.prompt_overrides, field_name)
                if value is None:
                    continue
                # bool is a subclass of int, so it has to be excluded explicitly.
                if isinstance(value, bool) or not isinstance(value, int) or value < 0:
                    raise ValueError(
                        "multitask.foreign_inspirations.prompt_overrides."
                        f"{field_name} must be a non-negative integer"
                    )
        if fi_cfg.reward_mode not in {"sparse", "rich"}:
            raise ValueError(
                "multitask.foreign_inspirations.reward_mode must be one of "
                "'sparse' or 'rich'"
            )
        if fi_cfg.reward_window < 1:
            raise ValueError(
                "multitask.foreign_inspirations.reward_window must be at least one"
            )
        if fi_cfg.reward_margin < 0:
            raise ValueError(
                "multitask.foreign_inspirations.reward_margin cannot be negative"
            )
        if fi_cfg.warmup_task_iterations < 0:
            raise ValueError(
                "multitask.foreign_inspirations.warmup_task_iterations cannot be negative"
            )
        if fi_cfg.max_related_tasks < 0:
            raise ValueError(
                "multitask.foreign_inspirations.max_related_tasks cannot be negative"
            )
        if fi_cfg.top_programs_per_related_task <= 0:
            raise ValueError(
                "multitask.foreign_inspirations.top_programs_per_related_task must be greater than zero"
            )

        if fi_cfg.trigger_mode == "periodic":
            if fi_cfg.every_n_task_iterations <= 0:
                raise ValueError(
                    "multitask.foreign_inspirations.every_n_task_iterations must be greater than zero"
                )
            return

        # The remaining checks cover both "stagnation" and "online_bandit".
        if fi_cfg.stagnation_patience < 1:
            raise ValueError(
                "multitask.foreign_inspirations.stagnation_patience must be at least one"
            )
        if fi_cfg.transfer_cooldown < 0:
            raise ValueError(
                "multitask.foreign_inspirations.transfer_cooldown cannot be negative"
            )
        if fi_cfg.min_best_fitness_improvement < 0:
            raise ValueError(
                "multitask.foreign_inspirations.min_best_fitness_improvement cannot be negative"
            )
        if fi_cfg.trigger_mode == "online_bandit":
            if fi_cfg.min_pulls_per_arm < 1:
                raise ValueError(
                    "multitask.foreign_inspirations.min_pulls_per_arm must be at least one"
                )
            if not (0 < fi_cfg.bandit_decay <= 1):
                raise ValueError(
                    "multitask.foreign_inspirations.bandit_decay must be in the interval (0, 1]"
                )
            if fi_cfg.max_related_tasks != 1:
                raise ValueError(
                    "multitask.foreign_inspirations.max_related_tasks must equal 1 when "
                    "trigger_mode='online_bandit'"
                )

    def _validate_tasks(self) -> None:
        """Check output directory uniqueness and related-task references."""
        task_names = {task.name for task in self.tasks}
        output_subdirs = set()
        for task in self.tasks:
            if task.output_subdir in output_subdirs:
                raise ValueError(
                    f"Duplicate multitask task output_subdir: {task.output_subdir}"
                )
            output_subdirs.add(task.output_subdir)
            for related_task in task.related_tasks:
                if related_task.source_task not in task_names:
                    raise ValueError(
                        f"Task '{task.name}' references unknown related task '{related_task.source_task}'"
                    )
                if related_task.source_task == task.name:
                    raise ValueError(
                        f"Task '{task.name}' cannot list itself as a related task"
                    )


def load_multitask_config(path: str | Path) -> MultitaskConfig:
    """Load a multitask configuration from YAML.

    Convenience wrapper that delegates to ``MultitaskConfig.from_yaml``.

    Args:
        path: Location of the YAML configuration file.

    Returns:
        The configuration returned by ``MultitaskConfig.from_yaml``, which
        validates it before returning.
    """
    return MultitaskConfig.from_yaml(path)


def load_base_task_config(multitask_config: MultitaskConfig) -> Config:
    """Build the base OpenEvolve config shared by the tasks of a multitask run.

    Args:
        multitask_config: Multitask configuration supplying ``base_config_path``,
            the inline ``base_config`` overrides and ``config_dir``.

    Returns:
        The config from ``load_config``; when inline ``base_config`` overrides are
        present, the result of applying them through ``derive_task_config``.
    """
    # A missing or empty path is passed to load_config as None.
    shared_config = load_config(multitask_config.base_config_path or None)

    inline_overrides = multitask_config.base_config
    if not inline_overrides:
        return shared_config

    return derive_task_config(
        base_config=shared_config,
        overrides=inline_overrides,
        config_dir=Path(multitask_config.config_dir or "."),
    )


def derive_task_config(
    base_config: Config,
    overrides: Dict[str, Any],
    config_dir: str | Path,
) -> Config:
    """Produce a regular OpenEvolve ``Config`` for one task.

    Args:
        base_config: Config whose ``to_dict()`` output is the merge base.
        overrides: Mapping deep-merged on top of the base config dictionary.
        config_dir: Directory against which a relative ``prompt.template_dir``
            is resolved.

    Returns:
        A new ``Config`` built from the merged dictionary, with the database
        seed, prompt template directory and model system message adjusted as
        described in the comments below.
    """
    effective = Config.from_dict(deep_merge_dicts(base_config.to_dict(), overrides))

    # A top-level random_seed override also drives the database seed, unless the
    # overrides set database.random_seed explicitly.
    db_overrides = overrides.get("database")
    database_seed_pinned = isinstance(db_overrides, dict) and "random_seed" in db_overrides
    if "random_seed" in overrides and not database_seed_pinned:
        effective.database.random_seed = effective.random_seed

    # Relative template directories are anchored at config_dir.
    configured_template_dir = effective.prompt.template_dir
    if configured_template_dir and not Path(configured_template_dir).is_absolute():
        effective.prompt.template_dir = str(
            (Path(config_dir) / Path(configured_template_dir)).resolve()
        )

    # Push the effective prompt system message down to the model-level settings.
    effective.llm.update_model_params({"system_message": effective.prompt.system_message})
    return effective
