#!/usr/bin/env python3
"""Spawn task-local checkpoints from an MT-STS shared checkpoint."""

from __future__ import annotations

import argparse
from pathlib import Path
import sys

# Put the directory one level above this script's folder at the front of
# ``sys.path`` (unless it is already listed) so that imports performed after
# this point can be resolved from there.
REPO_ROOT = Path(__file__).resolve().parents[1]
_REPO_ROOT_ENTRY = str(REPO_ROOT)
if _REPO_ROOT_ENTRY not in sys.path:
    sys.path.insert(0, _REPO_ROOT_ENTRY)

# Manifest paths advertised by the command-line help.  They are plain relative
# strings that all share one directory prefix.
# NOTE(review): how a relative path is resolved depends on ``load_manifest`` --
# confirm whether it is relative to the working directory or the repo root.
_MANIFEST_DIR = "multi_task_shared_then_adapt"

# Robust regression: the value used when ``--manifest`` is not supplied.
DEFAULT_MANIFEST = f"{_MANIFEST_DIR}/r_robust_regression_mt_sts.yaml"

# Circle packing (unit square and rectangle variants).
CIRCLE_PACKING_MANIFEST = f"{_MANIFEST_DIR}/circle_packing_mt_sts.yaml"
CIRCLE_PACKING_RECTANGLE_MANIFEST = (
    f"{_MANIFEST_DIR}/circle_packing_rectangle_mt_sts.yaml"
)

# K-module problem (easier and balanced variants).
K_MODULE_MANIFEST = f"{_MANIFEST_DIR}/k_module_problem_mt_sts.yaml"
K_MODULE_BALANCED_MANIFEST = (
    f"{_MANIFEST_DIR}/k_module_problem_balanced_mt_sts.yaml"
)

# Remaining task families, one manifest each.
FUNCTION_MINIMIZATION_MANIFEST = f"{_MANIFEST_DIR}/function_minimization_mt_sts.yaml"
HEILBRONN_TRIANGLE_MANIFEST = f"{_MANIFEST_DIR}/heilbronn_triangle_mt_sts.yaml"
HEXAGON_PACKING_MANIFEST = f"{_MANIFEST_DIR}/hexagon_packing_mt_sts.yaml"
SIGNAL_PROCESSING_MANIFEST = f"{_MANIFEST_DIR}/signal_processing_mt_sts.yaml"
SYMBOLIC_REGRESSION_PHYS_OSC_MANIFEST = (
    f"{_MANIFEST_DIR}/symbolic_regression_phys_osc_mt_sts.yaml"
)
SLDBENCH_3D_MANIFEST = f"{_MANIFEST_DIR}/sldbench_3d_mt_sts.yaml"
RUST_ADAPTIVE_SORT_MANIFEST = f"{_MANIFEST_DIR}/rust_adaptive_sort_mt_sts.yaml"

from openevolve.multi_task_shared_then_specialize.runner import write_json
from openevolve.multi_task_shared_then_specialize.spawn import spawn_task_checkpoints
from openevolve.multi_task_shared_then_specialize.workflow import family_task_specs, load_manifest


def parse_args() -> argparse.Namespace:
    """Build the command-line parser for this script and parse the arguments.

    Returns:
        The parsed namespace with ``manifest`` (falls back to
        ``DEFAULT_MANIFEST``), ``shared_checkpoint`` and ``output_root``
        (both mandatory).
    """
    # Manifests the description lists without a leading "or".
    plain_clauses = (
        (CIRCLE_PACKING_MANIFEST, "unit-square circle packing"),
        (CIRCLE_PACKING_RECTANGLE_MANIFEST, "rectangle circle packing"),
        (K_MODULE_MANIFEST, "the easier K-module family"),
        (K_MODULE_BALANCED_MANIFEST, "the harder balanced K-module family"),
    )
    # Manifests the description introduces with "or".
    or_clauses = (
        (FUNCTION_MINIMIZATION_MANIFEST, "function minimization"),
        (HEILBRONN_TRIANGLE_MANIFEST, "Heilbronn triangle"),
        (HEXAGON_PACKING_MANIFEST, "hexagon packing"),
        (SIGNAL_PROCESSING_MANIFEST, "signal processing"),
        (SYMBOLIC_REGRESSION_PHYS_OSC_MANIFEST, "symbolic regression physics oscillators"),
        (SLDBENCH_3D_MANIFEST, "SLDBench 3D scaling laws"),
        (RUST_ADAPTIVE_SORT_MANIFEST, "Rust adaptive sort"),
    )
    usage_clauses = [f"--manifest {path} for {purpose}" for path, purpose in plain_clauses]
    usage_clauses += [f"or --manifest {path} for {purpose}" for path, purpose in or_clauses]
    description = (
        "Spawn task checkpoints from a shared MT-STS checkpoint. "
        "Defaults to the robust-regression manifest; use "
        + ", ".join(usage_clauses)
        + "."
    )

    # (label, path) pairs rendered as "<label>: <path>." in the --manifest help.
    manifest_help_entries = (
        ("Default", DEFAULT_MANIFEST),
        ("Unit-square circle packing", CIRCLE_PACKING_MANIFEST),
        ("Rectangle circle packing", CIRCLE_PACKING_RECTANGLE_MANIFEST),
        ("Easier K-module", K_MODULE_MANIFEST),
        ("Balanced K-module", K_MODULE_BALANCED_MANIFEST),
        ("Function minimization", FUNCTION_MINIMIZATION_MANIFEST),
        ("Heilbronn triangle", HEILBRONN_TRIANGLE_MANIFEST),
        ("Hexagon packing", HEXAGON_PACKING_MANIFEST),
        ("Signal processing", SIGNAL_PROCESSING_MANIFEST),
        ("Symbolic regression physics oscillators", SYMBOLIC_REGRESSION_PHYS_OSC_MANIFEST),
        ("SLDBench 3D", SLDBENCH_3D_MANIFEST),
        ("Rust adaptive sort", RUST_ADAPTIVE_SORT_MANIFEST),
    )
    manifest_help = "Path to the MT-STS manifest. " + " ".join(
        f"{label}: {path}." for label, path in manifest_help_entries
    )

    cli = argparse.ArgumentParser(description=description)
    cli.add_argument("--manifest", default=DEFAULT_MANIFEST, help=manifest_help)
    cli.add_argument(
        "--shared-checkpoint",
        required=True,
        help="Path to the shared checkpoint directory",
    )
    cli.add_argument(
        "--output-root",
        required=True,
        help="Directory where spawned task checkpoints should be written",
    )
    return cli.parse_args()


def main() -> int:
    """Spawn one checkpoint per task listed in the manifest and report the result.

    Loads the manifest named on the command line, collects the ``task_id`` of
    every spec returned by ``family_task_specs``, hands them to
    ``spawn_task_checkpoints`` together with the manifest's settings, and
    writes the returned results to ``spawn_summary.json`` under the resolved
    output root.

    Returns:
        ``0`` once the summary has been written and both status lines printed.
    """
    cli_args = parse_args()
    mt_sts_manifest = load_manifest(cli_args.manifest)
    family_task_ids = [spec.task_id for spec in family_task_specs(mt_sts_manifest)]
    # Resolve once so the spawn call, the summary file and the printed
    # messages all refer to the same absolute location.
    destination = Path(cli_args.output_root).resolve()

    results = spawn_task_checkpoints(
        shared_checkpoint_path=cli_args.shared_checkpoint,
        output_root=destination,
        base_config_path=mt_sts_manifest.base_config,
        evaluation_file=mt_sts_manifest.evaluation_file,
        family=mt_sts_manifest.family,
        task_ids=family_task_ids,
        initial_program=mt_sts_manifest.initial_program,
    )

    written_summary = write_json(destination / "spawn_summary.json", results)
    print(f"Spawned checkpoints written under {destination}")
    print(f"Spawn summary written to {written_summary}")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
