import importlib.util
from pathlib import Path
import json
import sys

import pytest
import yaml

# Repository root: the parent of the directory that contains this file.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Put the repository root at the front of sys.path (at most once) before the
# `openevolve` imports below are executed.
# NOTE(review): presumably this lets the tests import the in-tree package
# without an installed distribution - confirm.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from openevolve.multi_task_shared_then_specialize.sldbench_3d import (
    build_generic_system_prompt_for_sldbench_3d,
    build_task_specific_canonical_system_prompt_for_sldbench_3d,
    get_sldbench_3d_system_prompt,
)
from openevolve.multi_task_shared_then_specialize.trials import build_setting_output_dir_name
from openevolve.multi_task_shared_then_specialize.workflow import (
    build_mt_sts_setting_slug,
    build_phase_wandb_config,
    default_mt_sts_run_prefix,
    family_task_specs,
    fair_mt_sts_baseline_iterations,
    load_manifest,
    resolve_phase_system_prompt,
    write_phase_config,
)


# Manifest and base-config locations, all resolved relative to REPO_ROOT.
DEFAULT_MANIFEST_PATH = REPO_ROOT.joinpath(
    "multi_task_shared_then_adapt",
    "sldbench_3d_mt_sts.yaml",
)
TASK_SPECIFIC_CANONICAL_MANIFEST_PATH = REPO_ROOT.joinpath(
    "multi_task_shared_then_adapt",
    "sldbench_3d_mt_sts_task_specific_canonical.yaml",
)
BASE_CONFIG_PATH = REPO_ROOT.joinpath("examples", "sldbench_3d_mt_sts", "config.yaml")


def _load_mt_sts_reporter_module():
    """Import the MT-STS reporter from its file path and return the module.

    The reporter source is read from
    ``multi_task_shared_then_adapt/report_mt_sts_results.py`` under
    ``REPO_ROOT`` and executed under a private, test-specific module name.

    Returns:
        The executed reporter module object.

    Raises:
        ImportError: If no import spec or loader can be created for the file.
    """
    module_name = "mt_sts_reporter_test_module_sldbench"
    reporter_path = (
        REPO_ROOT / "multi_task_shared_then_adapt" / "report_mt_sts_results.py"
    )
    spec = importlib.util.spec_from_file_location(module_name, reporter_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not load MT-STS reporter from {reporter_path}")
    module = importlib.util.module_from_spec(spec)
    # Register the module before executing it, as the importlib recipe for
    # importing a source file directly does: code that resolves itself through
    # sys.modules[__name__] while being imported would otherwise not find it.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Never leave a half-initialised module registered.
        sys.modules.pop(module_name, None)
        raise
    return module


# Load the reporter once at import time; the tests below reach it through this
# module-level name.
mt_sts_reporter = _load_mt_sts_reporter_module()


def _read_system_prompt(config_path: Path) -> str:
    """Return the stripped ``prompt.system_message`` of a YAML config file.

    Args:
        config_path: Path of the YAML config to read (decoded as UTF-8).

    Returns:
        The system message with surrounding whitespace removed, or ``""`` when
        the file is empty, has no ``prompt`` section, or has no (or a null)
        ``system_message``.
    """
    config = yaml.safe_load(config_path.read_text(encoding="utf-8")) or {}
    prompt_section = config.get("prompt") or {}
    system_message = prompt_section.get("system_message")
    if system_message is None:
        # An explicit ``system_message: null`` must read as empty rather than
        # as the literal string "None" that str(None) would produce.
        return ""
    return str(system_message).strip()


def test_sldbench_3d_manifest_prompt_mode_defaults_stay_generic():
    """The default manifest keeps all three phases on the generic prompt mode."""
    manifest = load_manifest(DEFAULT_MANIFEST_PATH)

    for phase in ("shared", "adaptation", "baseline"):
        assert getattr(manifest, f"{phase}_prompt_mode") == "generic"


def test_sldbench_3d_task_specific_canonical_manifest_sets_requested_modes():
    """The canonical manifest switches adaptation and baseline, not shared."""
    manifest = load_manifest(TASK_SPECIFIC_CANONICAL_MANIFEST_PATH)

    expected_modes = {
        "shared_prompt_mode": "generic",
        "adaptation_prompt_mode": "task_specific_canonical",
        "baseline_prompt_mode": "task_specific_canonical",
    }
    for attribute, expected_mode in expected_modes.items():
        assert getattr(manifest, attribute) == expected_mode


@pytest.mark.parametrize(
    "manifest_path",
    [DEFAULT_MANIFEST_PATH, TASK_SPECIFIC_CANONICAL_MANIFEST_PATH],
)
def test_sldbench_3d_default_iteration_values_are_fair(manifest_path: Path):
    """Both manifests pin 14/10/17 defaults that match the fair-baseline helper."""
    manifest = load_manifest(manifest_path)

    shared = manifest.default_shared_iterations
    adaptation = manifest.default_adaptation_iterations
    baseline = manifest.default_baseline_iterations
    assert shared == 14
    assert adaptation == 10
    assert baseline == 17

    # The manifest's baseline default must equal what the helper computes for
    # two tasks from the shared and adaptation defaults.
    fair_baseline = fair_mt_sts_baseline_iterations(
        task_count=2,
        shared_iterations=shared,
        adaptation_iterations=adaptation,
    )
    assert fair_baseline == baseline


def test_manifest_rejects_non_generic_shared_prompt_mode(tmp_path):
    """Loading a manifest whose shared prompt mode is not generic raises ValueError."""
    manifest_data = yaml.safe_load(DEFAULT_MANIFEST_PATH.read_text(encoding="utf-8")) or {}
    manifest_data["shared_prompt_mode"] = "task_specific_canonical"

    # Write the tampered copy of the default manifest to a temporary file.
    invalid_manifest = tmp_path / "invalid_shared_prompt_mode.yaml"
    serialized = yaml.safe_dump(manifest_data, sort_keys=False)
    invalid_manifest.write_text(serialized, encoding="utf-8")

    expected_error = "shared_prompt_mode currently only supports 'generic'"
    with pytest.raises(ValueError, match=expected_error):
        load_manifest(invalid_manifest)


def test_generic_prompt_builder_matches_default_config_prompt():
    """The generic builder reproduces the system message stored in the base config."""
    built_prompt = build_generic_system_prompt_for_sldbench_3d()
    config_prompt = _read_system_prompt(BASE_CONFIG_PATH)
    assert built_prompt == config_prompt


def test_task_specific_canonical_prompt_builder_for_vocab_scaling_law():
    """The canonical prompt for ``vocab_scaling_law`` contains its variable mapping."""
    prompt = get_sldbench_3d_system_prompt(
        task_id="vocab_scaling_law",
        prompt_mode="task_specific_canonical",
    )

    expected_fragments = (
        "specialized for the `vocab_scaling_law` task",
        "- model_size_like = non_vocab_parameters",
        "- diversity_like = vocab_size",
        "- total_data_like = num_characters",
        "- unigram_normalized_loss",
    )
    for fragment in expected_fragments:
        assert fragment in prompt


def test_task_specific_canonical_prompt_builder_for_data_constrained_scaling_law():
    """The canonical prompt for ``data_constrained_scaling_law`` contains its mapping."""
    task_id = "data_constrained_scaling_law"
    prompt = build_task_specific_canonical_system_prompt_for_sldbench_3d(task_id)

    expected_fragments = (
        "specialized for the `data_constrained_scaling_law` task",
        "- model_size_like = params",
        "- diversity_like = unique_tokens",
        "- total_data_like = tokens",
        "- loss",
    )
    for fragment in expected_fragments:
        assert fragment in prompt


def test_task_specific_canonical_prompt_builder_rejects_missing_or_unknown_task():
    """Canonical prompts need a known task id; both failure modes raise ValueError."""
    missing_task_error = "requires a concrete SLDBench 3D task_id"
    unknown_task_error = "Unknown SLDBench 3D task"

    # No task id supplied at all.
    with pytest.raises(ValueError, match=missing_task_error):
        get_sldbench_3d_system_prompt(prompt_mode="task_specific_canonical")

    # A task id the builder does not recognise.
    with pytest.raises(ValueError, match=unknown_task_error):
        build_task_specific_canonical_system_prompt_for_sldbench_3d("unknown_task")


def test_adaptation_config_generation_uses_generic_prompt_by_default(tmp_path):
    """An adaptation config written from the default manifest holds the generic prompt."""
    manifest = load_manifest(DEFAULT_MANIFEST_PATH)
    written_config = tmp_path / "adaptation_generic.yaml"
    adaptation_prompt = resolve_phase_system_prompt(
        manifest,
        phase="adaptation",
        task_id="vocab_scaling_law",
    )

    write_phase_config(
        base_config_path=manifest.base_config,
        output_config_path=written_config,
        iterations=3,
        system_prompt=adaptation_prompt,
    )

    written_prompt = _read_system_prompt(written_config)
    assert written_prompt == build_generic_system_prompt_for_sldbench_3d()


def test_adaptation_config_generation_uses_task_specific_canonical_prompt_when_requested(
    tmp_path,
):
    """With the canonical manifest, the adaptation config holds the task prompt."""
    manifest = load_manifest(TASK_SPECIFIC_CANONICAL_MANIFEST_PATH)
    written_config = tmp_path / "adaptation_vocab_canonical.yaml"
    adaptation_prompt = resolve_phase_system_prompt(
        manifest,
        phase="adaptation",
        task_id="vocab_scaling_law",
    )

    write_phase_config(
        base_config_path=manifest.base_config,
        output_config_path=written_config,
        iterations=3,
        system_prompt=adaptation_prompt,
    )

    written_prompt = _read_system_prompt(written_config)
    expected_prompt = build_task_specific_canonical_system_prompt_for_sldbench_3d(
        "vocab_scaling_law"
    )
    assert written_prompt == expected_prompt


def test_baseline_config_generation_uses_generic_prompt_by_default(tmp_path):
    """A baseline config written from the default manifest holds the generic prompt."""
    manifest = load_manifest(DEFAULT_MANIFEST_PATH)
    written_config = tmp_path / "baseline_generic.yaml"
    baseline_prompt = resolve_phase_system_prompt(
        manifest,
        phase="baseline",
        task_id="data_constrained_scaling_law",
    )

    write_phase_config(
        base_config_path=manifest.base_config,
        output_config_path=written_config,
        iterations=3,
        system_prompt=baseline_prompt,
    )

    written_prompt = _read_system_prompt(written_config)
    assert written_prompt == build_generic_system_prompt_for_sldbench_3d()


def test_baseline_config_generation_uses_task_specific_canonical_prompt_when_requested(
    tmp_path,
):
    """With the canonical manifest, the baseline config holds the task prompt."""
    manifest = load_manifest(TASK_SPECIFIC_CANONICAL_MANIFEST_PATH)
    written_config = tmp_path / "baseline_data_canonical.yaml"
    baseline_prompt = resolve_phase_system_prompt(
        manifest,
        phase="baseline",
        task_id="data_constrained_scaling_law",
    )

    write_phase_config(
        base_config_path=manifest.base_config,
        output_config_path=written_config,
        iterations=3,
        system_prompt=baseline_prompt,
    )

    written_prompt = _read_system_prompt(written_config)
    expected_prompt = build_task_specific_canonical_system_prompt_for_sldbench_3d(
        "data_constrained_scaling_law"
    )
    assert written_prompt == expected_prompt


def test_default_naming_stays_stable_for_all_generic_manifest(tmp_path):
    """An all-generic manifest adds no prompt-mode suffixes to prefixes, slugs or names."""
    manifest = load_manifest(DEFAULT_MANIFEST_PATH)

    run_prefix = default_mt_sts_run_prefix(base_prefix="mt_sts", manifest=manifest)
    assert run_prefix == "mt_sts"

    setting_slug = build_mt_sts_setting_slug(
        shared_iterations=15,
        adaptation_iterations=10,
        baseline_iterations=12,
        shared_prompt_mode=manifest.shared_prompt_mode,
        adaptation_prompt_mode=manifest.adaptation_prompt_mode,
        baseline_prompt_mode=manifest.baseline_prompt_mode,
    )
    assert setting_slug == "s15-a10-b12"

    phase_config = build_phase_wandb_config(
        manifest,
        run_name="mt_sts_20260407_000000",
        run_root=tmp_path / "run",
        phase="adaptation",
        task_id="vocab_scaling_law",
        shared_iterations=15,
        adaptation_iterations=10,
        baseline_iterations=12,
    )
    # Neither canonical suffix may leak into the W&B run name.
    for canonical_suffix in ("adaptcanon", "basecanon"):
        assert canonical_suffix not in phase_config["name"]
    assert "shared-generic" in phase_config["tags"]


def test_task_specific_canonical_modes_add_distinct_suffixes_to_names_and_tags(tmp_path):
    """Canonical prompt modes show up in prefixes, slugs, dir names and W&B metadata."""
    manifest = load_manifest(TASK_SPECIFIC_CANONICAL_MANIFEST_PATH)

    run_prefix = default_mt_sts_run_prefix(base_prefix="mt_sts", manifest=manifest)
    assert run_prefix == "mt_sts-adaptcanon-basecanon"

    setting_slug = build_mt_sts_setting_slug(
        shared_iterations=15,
        adaptation_iterations=10,
        baseline_iterations=12,
        shared_prompt_mode=manifest.shared_prompt_mode,
        adaptation_prompt_mode=manifest.adaptation_prompt_mode,
        baseline_prompt_mode=manifest.baseline_prompt_mode,
    )
    assert setting_slug == "s15-a10-b12-adaptcanon-basecanon"

    output_dir_name = build_setting_output_dir_name(
        shared_iterations=15,
        adaptation_iterations=10,
        baseline_iterations=12,
        shared_prompt_mode=manifest.shared_prompt_mode,
        adaptation_prompt_mode=manifest.adaptation_prompt_mode,
        baseline_prompt_mode=manifest.baseline_prompt_mode,
        primary_model="claude-sonnet-4-6",
        edit_mode="full",
    )
    assert output_dir_name == "s15-a10-b12-adaptcanon-basecanon-claude-sonnet-4-6-full"

    phase_config = build_phase_wandb_config(
        manifest,
        run_name="mt_sts-adaptcanon-basecanon_20260407_000000",
        run_root=tmp_path / "run",
        phase="adaptation",
        task_id="vocab_scaling_law",
        shared_iterations=15,
        adaptation_iterations=10,
        baseline_iterations=12,
    )
    assert "adaptcanon-basecanon" in phase_config["name"]
    for expected_tag in ("shared-generic", "adaptcanon", "basecanon"):
        assert expected_tag in phase_config["tags"]

    expected_modes = {
        "shared_prompt_mode": "generic",
        "adaptation_prompt_mode": "task_specific_canonical",
        "baseline_prompt_mode": "task_specific_canonical",
    }
    for config_key, expected_mode in expected_modes.items():
        assert phase_config[config_key] == expected_mode

    expected_note = (
        "Prompt modes: shared=generic, adaptation=task_specific_canonical, "
        "baseline=task_specific_canonical."
    )
    assert expected_note in phase_config["notes"]


def test_reporter_separates_sldbench_prompt_mode_variants(tmp_path):
    """Generic and task-specific-canonical runs are reported as distinct settings.

    Two synthetic run directories are written under ``tmp_path``, one per
    manifest, each holding the phase configs and ``comparison_summary.json``
    created below. Both are then loaded through the reporter's
    ``load_run_report`` and their ``setting`` entries are compared.
    """
    generic_manifest = load_manifest(DEFAULT_MANIFEST_PATH)
    task_specific_manifest = load_manifest(TASK_SPECIFIC_CANONICAL_MANIFEST_PATH)
    task_specs = family_task_specs(generic_manifest)

    # Named once so the W&B config and the written phase configs cannot drift
    # apart; the setting directory names below spell out the same values.
    shared_iterations = 40
    adaptation_iterations = 15
    baseline_iterations = 35

    def build_run(
        setting_dir_name: str,
        run_name: str,
        *,
        manifest,
    ) -> dict:
        """Write a synthetic run for ``manifest`` and return its loaded report.

        The return type is a plain ``dict`` because the report is indexed as a
        nested mapping by the assertions (``report["setting"]["id"]``).
        """
        run_root = tmp_path / setting_dir_name / run_name
        configs_root = run_root / "configs"
        configs_root.mkdir(parents=True, exist_ok=True)
        wandb_config = build_phase_wandb_config(
            manifest,
            run_name=run_name,
            run_root=run_root,
            phase="shared",
            shared_iterations=shared_iterations,
            adaptation_iterations=adaptation_iterations,
            baseline_iterations=baseline_iterations,
        )
        shared_config = {
            "max_iterations": shared_iterations,
            "diff_based_evolution": False,
            "llm": {"primary_model": "claude-sonnet-4-6"},
            "wandb": wandb_config,
        }
        (configs_root / "shared_config.yaml").write_text(
            yaml.safe_dump(shared_config, sort_keys=False),
            encoding="utf-8",
        )
        # One adaptation and one baseline config per task.
        per_phase_iterations = (
            ("adaptation", adaptation_iterations),
            ("baseline", baseline_iterations),
        )
        for task in task_specs:
            for prefix, iterations in per_phase_iterations:
                (configs_root / f"{prefix}_{task.task_id}.yaml").write_text(
                    yaml.safe_dump({"max_iterations": iterations}, sort_keys=False),
                    encoding="utf-8",
                )
        summary_payload = {
            "shared_prompt_mode": manifest.shared_prompt_mode,
            "adaptation_prompt_mode": manifest.adaptation_prompt_mode,
            "baseline_prompt_mode": manifest.baseline_prompt_mode,
        }
        (run_root / "comparison_summary.json").write_text(
            json.dumps(summary_payload),
            encoding="utf-8",
        )
        # NOTE(review): the default manifest path is passed for both runs,
        # including the task-specific one - confirm this is intentional.
        return mt_sts_reporter.load_run_report(
            run_root,
            repo_root=REPO_ROOT,
            manifest_path=DEFAULT_MANIFEST_PATH,
            manifest_family="sldbench_3d",
            task_specs=task_specs,
            wandb_entity_override=None,
        )

    generic_run = build_run(
        "s40-a15-b35-claude-sonnet-4-6-full",
        "run_01_seed_42",
        manifest=generic_manifest,
    )
    task_specific_run = build_run(
        "s40-a15-b35-adaptcanon-basecanon-claude-sonnet-4-6-full",
        "run_01_seed_42",
        manifest=task_specific_manifest,
    )

    generic_setting = generic_run["setting"]
    task_specific_setting = task_specific_run["setting"]

    assert generic_setting["id"] != task_specific_setting["id"]
    assert (
        generic_setting["adaptation_prompt_mode"],
        generic_setting["baseline_prompt_mode"],
    ) == ("generic", "generic")
    assert (
        task_specific_setting["adaptation_prompt_mode"],
        task_specific_setting["baseline_prompt_mode"],
    ) == ("task_specific_canonical", "task_specific_canonical")
    assert "prompts=" not in generic_setting["label"]
    assert "prompts=" not in task_specific_setting["label"]
    assert mt_sts_reporter.is_generic_prompt_setting(generic_setting)
    assert not mt_sts_reporter.is_generic_prompt_setting(task_specific_setting)
