#!/usr/bin/env python
"""
Simple grid search runner for toy_example.

Supports SPACO, MGD and RMPDPG on linear/nonlinear problems by emitting Hydra
overrides and invoking `python -m toy_example.main`. No third-party deps
required.
"""

from __future__ import annotations

import argparse
import itertools
import os
import subprocess
import sys
from typing import Iterable


def _as_override(key: str, value) -> str:
    """Render ``key=value`` as a single Hydra command-line override.

    Booleans are written as lowercase ``true``/``false``; any other value is
    interpolated using its normal string form.
    """
    if not isinstance(value, bool):
        return f"{key}={value}"
    # str(True) would give "True"; YAML expects the lowercase spelling.
    rendered = "true" if value else "false"
    return f"{key}={rendered}"


def _spaco_search_space() -> dict[str, dict[str, list]]:
    """Return the SPACO hyper-parameter grid, keyed by problem name.

    The result maps each problem (``"linear"``, ``"nonlinear"``) to a dict of
    ``parameter name -> list of candidate values``. Both problems currently
    sweep the same values, but each one gets its own dict and lists so that
    editing one problem's grid cannot silently change the other's.
    """
    params_grid = {
        "rho0": [1.0, 10.0, 50.0],
        "alpha0": [0.1, 0.01, 0.001],
        "beta0": [0.1, 0.01, 0.001],
        "prox0": [0.01, 0.001, 0.0001],
    }
    # Copy per problem instead of handing out the same object twice.
    return {
        problem: {name: list(values) for name, values in params_grid.items()}
        for problem in ("linear", "nonlinear")
    }


def _mgd_search_space() -> dict[str, dict[str, list]]:
    """Return the MGD hyper-parameter grid, keyed by problem name.

    The result maps each problem (``"linear"``, ``"nonlinear"``) to a dict of
    ``parameter name -> list of candidate values``. Both problems currently
    sweep the same values, but each one gets its own dict and lists so that
    editing one problem's grid cannot silently change the other's.
    """
    params_grid = {
        "alpha": [0.1, 0.01, 0.001],
        "beta": [0.1, 0.01, 0.001],
        "dual_stepsize": [1.0, 0.1, 0.01],
        "inner_steps": [1, 5, 10],
    }
    # Copy per problem instead of handing out the same object twice.
    return {
        problem: {name: list(values) for name, values in params_grid.items()}
        for problem in ("linear", "nonlinear")
    }


def _rmpdpg_search_space() -> dict[str, dict[str, list]]:
    """Return the RMPDPG hyper-parameter grid, keyed by problem name.

    The result maps each problem (``"linear"``, ``"nonlinear"``) to a dict of
    ``parameter name -> list of candidate values``. Both problems currently
    sweep the same values, but each one gets its own dict and lists so that
    editing one problem's grid cannot silently change the other's.
    """
    params_grid = {
        "alpha0": [0.1, 0.01, 0.001],
        "beta": [0.1, 0.01, 0.001],
        "dual_stepsize0": [1.0, 0.1, 0.01],
        # NOTE(review): "intplation0" looks like a truncated "interpolation0".
        # The key is emitted verbatim as an `algo.<key>` override, so confirm
        # it matches the algorithm config before renaming it.
        "intplation0": [1.0, 0.5, 0.1],
    }
    # Copy per problem instead of handing out the same object twice.
    return {
        problem: {name: list(values) for name, values in params_grid.items()}
        for problem in ("linear", "nonlinear")
    }


def iter_jobs(
    algo: str,
    problem: str,
) -> Iterable[tuple[list[str], dict[str, float | int | bool]]]:
    """
    Yield (overrides, param_dict) for each job of a single algo/problem pair.

    One job is produced for every point in the Cartesian product of the
    selected algorithm's grid. Parameter names are visited in sorted order, so
    the sequence of jobs is deterministic.

    Each job is a list of hydra overrides, e.g.:
    ["problem=linear", "algo=mgd/linear", "algo.alpha=0.001", ...]

    Args:
        algo: Algorithm to sweep: "spaco", "mgd" or "rmpdpg".
        problem: Key selecting the per-problem grid of that algorithm.

    Yields:
        A pair of (list of override strings, dict of the swept parameter
        values for that job).

    Raises:
        ValueError: If ``algo`` is not one of the supported algorithms.
        KeyError: If the algorithm's search space has no entry for ``problem``.

    This is a generator, so both errors surface when iteration starts rather
    than when the function is called.
    """
    # Map each algorithm to its grid factory so that only the requested
    # search space is built and the job loop is written once.
    space_factories = {
        "spaco": _spaco_search_space,
        "mgd": _mgd_search_space,
        "rmpdpg": _rmpdpg_search_space,
    }
    factory = space_factories.get(algo)
    if factory is None:
        raise ValueError(f"Unsupported algo: {algo}")

    space = factory()[problem]
    keys = sorted(space)
    # Overrides common to every job of this algo/problem pair.
    prefix = [f"problem={problem}", f"algo={algo}/{problem}"]
    for values in itertools.product(*(space[k] for k in keys)):
        params = dict(zip(keys, values))
        overrides = prefix + [_as_override(f"algo.{k}", v) for k, v in params.items()]
        yield overrides, params


def build_parser() -> argparse.ArgumentParser:
    """Construct the argparse parser for the grid search launcher.

    The default for ``--python-cmd`` is resolved when the parser is built:
    the ``MMCON_PYTHON`` environment variable if set, otherwise the running
    interpreter.
    """
    # Options are declared as data and registered in order, which is also the
    # order in which they are added to the parser.
    option_specs = [
        (
            "--algo",
            dict(
                choices=["spaco", "mgd", "rmpdpg"],
                default="spaco",
                help="Which algorithm to sweep.",
            ),
        ),
        (
            "--problem",
            dict(
                choices=["linear", "nonlinear"],
                default="linear",
                help="Which single problem to sweep.",
            ),
        ),
        (
            "--python-cmd",
            dict(
                default=os.environ.get("MMCON_PYTHON", sys.executable),
                help="Python command to invoke (default: current interpreter or $MMCON_PYTHON).",
            ),
        ),
        (
            "--use-uv",
            dict(
                action="store_true",
                help="Use `uv run` to execute instead of plain python.",
            ),
        ),
        (
            "--dry-run",
            dict(
                action="store_true",
                help="Print commands without executing.",
            ),
        ),
        (
            "--max-runs",
            dict(
                type=int,
                default=None,
                help="Optional cap on number of runs (useful for smoke tests).",
            ),
        ),
        (
            "--runs",
            dict(
                type=int,
                default=1,
                help="Repeat the full job list this many times; seeds vary only by run.",
            ),
        ),
        (
            "--group-name",
            dict(
                type=str,
                default=None,
                help="Optional SwanLab group name override (applied to all runs).",
            ),
        ),
        (
            "--hydra-run-dir",
            dict(
                type=str,
                default=None,
                help="Optional hydra.run.dir override (e.g., '.' to avoid new folders).",
            ),
        ),
        (
            "--seed-base",
            dict(
                type=int,
                default=42,
                help="Base seed; each run uses seed_base + run_index (0-based).",
            ),
        ),
        (
            "--group-default",
            dict(
                type=str,
                default=None,
                help="Optional default SwanLab group if --group-name not set; otherwise auto grid-{algo}-{problem}.",
            ),
        ),
    ]

    cli = argparse.ArgumentParser(description="Grid search launcher for toy_example.")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli


def main(argv: list[str] | None = None) -> int:
    """Run (or, with --dry-run, only print) every job of the selected grid.

    The job list for the chosen algo/problem is repeated ``--runs`` times;
    repeat ``i`` (0-based) uses seed ``seed_base + i`` for all of its jobs.
    A failing subprocess does not stop the sweep: the failure is reported and
    the remaining jobs still run.

    Args:
        argv: Arguments to parse; ``None`` is passed straight to argparse.

    Returns:
        0 if there was nothing to run, in dry-run mode, or if every launched
        subprocess exited with code 0; 1 if at least one subprocess failed.
    """
    # Local import: only needed to echo commands in a shell-safe form.
    import shlex

    parser = build_parser()
    args = parser.parse_args(argv)

    jobs = list(iter_jobs(args.algo, args.problem))
    if args.max_runs is not None:
        # Clamp at zero: a negative slice bound would drop jobs from the end
        # of the list instead of capping how many are run.
        jobs = jobs[: max(args.max_runs, 0)]

    if not jobs:
        print("No jobs to run.")
        return 0

    base_cmd: list[str]
    if args.use_uv:
        base_cmd = ["uv", "run", "python", "-m", "toy_example.main"]
    else:
        base_cmd = [args.python_cmd, "-m", "toy_example.main"]

    # Overrides that are identical for every job are built once.
    # Group precedence: --group-name, then --group-default, then auto name.
    group_name = args.group_name or args.group_default or f"grid-{args.algo}-{args.problem}"
    shared_overrides = [_as_override("group_name", group_name)]
    if args.hydra_run_dir:
        shared_overrides.append(_as_override("hydra.run.dir", args.hydra_run_dir))

    total_runs = len(jobs) * args.runs
    run_counter = 0
    failures = 0

    for run_idx in range(args.runs):
        seed = args.seed_base + run_idx
        for base_overrides, params in jobs:
            run_counter += 1

            overrides = list(base_overrides)
            overrides.extend(shared_overrides)
            overrides.append(_as_override("problem.seed", seed))

            # Build experiment name from parameters for clarity; the value is
            # wrapped in double quotes because it contains '=' and '|'.
            parts = [args.algo.upper(), args.problem]
            parts += [f"{k}={v}" for k, v in sorted(params.items())]
            if args.runs > 1:
                parts.append(f"run={run_idx}")
            exp_name = "|".join(parts)
            # Escape backslashes and quotes so the quoted value stays intact.
            exp_name_escaped = exp_name.replace("\\", "\\\\").replace('"', '\\"')
            overrides.append(f'experiment_name="{exp_name_escaped}"')

            cmd = [*base_cmd, *overrides]

            prefix = f"[{run_counter}/{total_runs}]"
            # Quote each argument so the echoed line can be pasted into a
            # shell as-is ('|' and '"' would otherwise be interpreted).
            print(prefix, " ".join(shlex.quote(part) for part in cmd))

            if args.dry_run:
                continue

            # Arguments are passed as a list, so no shell is involved here.
            completed = subprocess.run(cmd)
            if completed.returncode != 0:
                failures += 1
                print(f"{prefix} failed with code {completed.returncode}")

    # Surface failures through the exit status instead of always reporting
    # success, while still having attempted every job.
    return 1 if failures else 0


if __name__ == "__main__":
    # Use main()'s integer result as the process exit status.
    sys.exit(main())
