"""
Hydra entry point that runs the single-loop stochastic constrained minimax
algorithm (SPACO), or MGD / RMPDPG, on a toy problem and logs the run to SwanLab.
"""

from __future__ import annotations

import logging

import hydra
import numpy as np
import swanlab
from omegaconf import DictConfig, OmegaConf

Array = np.ndarray

from toy_example.problems import (
    BaseProblem,
    LinearToyExample,
    NonlinearToyExample,
    ToyExampleCfg,
    ToyExampleSol,
)
from toy_example.common import BasicAlgoConfig
from toy_example.algorithms.spaco import (
    SPACO,
    AlgoConfig as SPACOAlgoConfig,
)
from toy_example.algorithms.mgd import (
    MGD,
    AlgoConfig as MGDAlgoConfig,
)
from toy_example.algorithms.rmpdpg import (
    RMPDPG,
    AlgoConfig as RMPDPGAlgoConfig,
)


def make_spaco_config(algo_cfg: DictConfig, seed: int) -> SPACOAlgoConfig:
    """Build a SPACO config whose step-size schedules are polynomial in ``k``.

    With t = ``algo_cfg.pilot_t`` and s = ``algo_cfg.pilot_s`` the schedules are::

        rho(k)   = rho0   * k ** t
        alpha(k) = alpha0 * k ** (-6 t - s)
        beta(k)  = beta0  * k ** (-t - s)
        eta(k)   = eta0   * k ** (-s)
        prox(k)  = prox0  * k ** (-t)

    Args:
        algo_cfg: The ``algo`` node of the Hydra config.
        seed: Seed forwarded to the algorithm config.

    Returns:
        A populated ``SPACOAlgoConfig``.
    """
    # Read every scalar out of the DictConfig once. The schedule lambdas are
    # invoked on every solver iteration, and DictConfig attribute access is
    # much slower than a closure-local lookup. Missing keys now also fail here
    # instead of mid-run.
    pilot_t = algo_cfg.pilot_t
    pilot_s = algo_cfg.pilot_s
    rho0 = algo_cfg.rho0
    alpha0 = algo_cfg.alpha0
    beta0 = algo_cfg.beta0
    eta0 = algo_cfg.eta0
    prox0 = algo_cfg.prox0

    # Exponents are iteration-independent; compute them once.
    alpha_exp = -6 * pilot_t - pilot_s
    beta_exp = -pilot_t - pilot_s
    eta_exp = -pilot_s
    prox_exp = -pilot_t

    # NOTE(review): with positive pilot_t / pilot_s the negative exponents
    # require k > 0 (k = 0 raises ZeroDivisionError). Confirm that SPACO
    # counts iterations from 1; the RMPDPG schedules shift by 2 instead.
    return SPACOAlgoConfig(
        max_iters=algo_cfg.max_iters,
        log_every=algo_cfg.log_every,
        rho=lambda k: rho0 * k ** pilot_t,
        alpha=lambda k: alpha0 * k ** alpha_exp,
        beta=lambda k: beta0 * k ** beta_exp,
        eta=lambda k: eta0 * k ** eta_exp,
        prox=lambda k: prox0 * k ** prox_exp,
        is_storm=algo_cfg.is_storm,
        seed=seed,
    )


def make_mgd_config(algo_cfg: DictConfig, seed: int) -> MGDAlgoConfig:
    """Translate the Hydra ``algo`` node into an ``MGDAlgoConfig``.

    No schedules are built here: each listed field is copied from
    ``algo_cfg`` unchanged, and ``seed`` is forwarded as given.
    """
    copied_fields = (
        "max_iters",
        "log_every",
        "alpha",
        "beta",
        "dual_stepsize",
        "inner_steps",
    )
    # getattr keeps DictConfig's attribute-access semantics (and error type).
    values = {field: getattr(algo_cfg, field) for field in copied_fields}
    return MGDAlgoConfig(seed=seed, **values)


def make_rmpdpg_config(algo_cfg: DictConfig, seed: int) -> RMPDPGAlgoConfig:
    """Build an RMPDPG config with shifted polynomial step-size schedules.

    Every schedule is evaluated at ``k + 2``, so the negative powers are
    well-defined for any ``k >= -1``::

        alpha(k)         = alpha0         * (k + 2) ** (-4/13)
        eta_x(k)         = eta0_x         * (k + 2) ** (-12/13)
        eta_y(k)         = eta0_y         * (k + 2) ** (-8/13)
        dual_stepsize(k) = dual_stepsize0 * (k + 2) ** (-4/13)
        intpl(k)         = intplation0    * (k + 2) ** (-5/13)

    ``beta`` is passed through as a constant.

    Args:
        algo_cfg: The ``algo`` node of the Hydra config.
        seed: Seed forwarded to the algorithm config.

    Returns:
        A populated ``RMPDPGAlgoConfig``.
    """
    # Read the scalars out of the DictConfig once. The schedule lambdas run
    # on every iteration, and DictConfig attribute access is comparatively
    # slow. Missing keys now also fail here instead of mid-run.
    alpha0 = algo_cfg.alpha0
    eta0_x = algo_cfg.eta0_x
    eta0_y = algo_cfg.eta0_y
    dual_stepsize0 = algo_cfg.dual_stepsize0
    intpl0 = algo_cfg.intplation0  # (sic) key spelling used by the config files

    return RMPDPGAlgoConfig(
        max_iters=algo_cfg.max_iters,
        log_every=algo_cfg.log_every,
        seed=seed,
        is_storm=algo_cfg.is_storm,
        alpha=lambda k: alpha0 * (k + 2) ** (-4 / 13),
        beta=algo_cfg.beta,
        eta_x=lambda k: eta0_x * (k + 2) ** (-12 / 13),
        eta_y=lambda k: eta0_y * (k + 2) ** (-8 / 13),
        dual_stepsize=lambda k: dual_stepsize0 * (k + 2) ** (-4 / 13),
        intpl=lambda k: intpl0 * (k + 2) ** (-5 / 13),
    )


def infer_experiment_name(algo: str) -> str:
    """Return the display name used for SwanLab runs of algorithm ``algo``.

    Raises:
        ValueError: if ``algo`` is not one of ``"spaco"``, ``"mgd"``, ``"rmpdpg"``.
    """
    display_names = {"spaco": "SPACO", "mgd": "MGD", "rmpdpg": "RMPDPG"}
    try:
        return display_names[algo]
    except KeyError:
        raise ValueError(f"Unknown algorithm: {algo}") from None


# --- Tiny wiring helper using the existing toy example ----------------------------------------
def build_solver(
    *,
    algo_cls,
    cfg: BasicAlgoConfig,
    problem: BaseProblem,
):
    """Wire ``problem`` into a freshly constructed ``algo_cls`` solver.

    Projections, the constraint residual and noise sampling are handed to the
    solver as thin closures over ``problem``. The gradient callables are the
    problem's bound methods. The objective closure evaluates
    ``problem.stochastic_objective`` on one noise batch drawn here, so every
    objective evaluation during the run reuses the same samples.

    Keyword Args:
        algo_cls: Solver class; it is called with keyword arguments only.
        cfg: Algorithm configuration, forwarded as ``cfg``.
        problem: Source of projections, gradients, constraints and noise.

    Returns:
        Whatever ``algo_cls(...)`` returns.
    """

    def _project_x(v: Array) -> Array:
        return problem.project_x(v)

    def _project_y(v: Array) -> Array:
        return problem.project_y(v)

    def _draw_noise(count: int | None = None) -> Array:
        return problem.sample_noise(count)

    def _residual(x: Array, y: Array) -> Array:
        return problem.constraint_residual(x, y)

    # One fixed batch, drawn before exact_solution() to keep the original
    # order of calls into the problem.
    fixed_noise = problem.sample_noise()

    def _objective(x: Array, y: Array) -> float:
        return problem.stochastic_objective(x, y, noise_samples=fixed_noise)

    reference_solution = problem.exact_solution()

    solver_kwargs = {
        "cfg": cfg,
        "proj_x": _project_x,
        "proj_y": _project_y,
        "constraint": _residual,
        "grad_y": problem.gradient_y,
        "grad_x": problem.gradient_x,
        "grad_cx": problem.constraint_gradient_x,
        "grad_cy": problem.constraint_gradient_y,
        "noise_sampler": _draw_noise,
        "objective_fn": _objective,
        "exact_sol": reference_solution,
    }
    return algo_cls(**solver_kwargs)


def build_problem(problem_cfg: DictConfig) -> BaseProblem:
    """Instantiate the toy problem selected by ``problem_cfg.name``.

    The shared ``ToyExampleCfg`` is built first from the ``dimension``,
    ``noise``, ``seed``, ``num_samples`` and ``noise_refresh`` keys.

    Raises:
        ValueError: if ``problem_cfg.name`` is neither ``"linear"`` nor
            ``"nonlinear"``.
    """
    toy_cfg = ToyExampleCfg(
        dimension=problem_cfg.dimension,
        noise_level=problem_cfg.noise,
        seed=problem_cfg.seed,
        num_noise_samples=problem_cfg.num_samples,
        noise_pool_refresh_threshold=problem_cfg.noise_refresh,
    )
    problem_classes = {
        "linear": LinearToyExample,
        "nonlinear": NonlinearToyExample,
    }
    problem_cls = problem_classes.get(problem_cfg.name)
    if problem_cls is None:
        raise ValueError(f"Unknown problem: {problem_cfg.name}")
    return problem_cls(config=toy_cfg)


@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    """Run one experiment: build the toy problem, then run the selected solver.

    The algorithm is chosen by ``cfg.algo.name`` (``"spaco"``, ``"mgd"`` or
    ``"rmpdpg"``). A SwanLab run is opened whose project defaults to
    ``MinMaxCon-Toy-<Problem>`` and whose experiment name defaults to the
    algorithm's display name.

    Raises:
        ValueError: if ``cfg.algo.name`` is unknown (raised before any SwanLab
            run is opened), or if ``cfg.problem.name`` is unknown.
    """
    algo_name = cfg.algo.name

    # Resolve the algorithm up front so an unknown name fails before a
    # tracking run is created and before the problem is built. Previously it
    # was only caught at the very end whenever experiment_name was set.
    solvers = {
        "spaco": (SPACO, make_spaco_config),
        "mgd": (MGD, make_mgd_config),
        "rmpdpg": (RMPDPG, make_rmpdpg_config),
    }
    if algo_name not in solvers:
        raise ValueError(f"Unknown algorithm: {algo_name}")
    algo_cls, make_algo_config = solvers[algo_name]

    experiment_name = cfg.experiment_name or infer_experiment_name(algo_name)
    project = (
        cfg.project
        if cfg.project is not None
        else f"MinMaxCon-Toy-{cfg.problem.name.title()}"
    )

    swanlab.init(
        project=project,
        experiment_name=experiment_name,
        group=cfg.group_name,
        config=OmegaConf.to_container(cfg, resolve=True),
    )

    logging.basicConfig(
        level=logging.INFO,
        format="[%(levelname)s] %(message)s",
    )

    problem = build_problem(cfg.problem)
    exact_sol = problem.exact_solution()
    # f-string (not lazy %-args) so the arrays are formatted while the
    # printoptions context is still active.
    with np.printoptions(precision=2, suppress=False):
        logging.info(f"Exact solution:\nx*={exact_sol.x}\ny*={exact_sol.y}")
    init_guess = problem.generate_random_guess()

    algo_cfg = make_algo_config(cfg.algo, seed=cfg.problem.seed)
    solver = build_solver(algo_cls=algo_cls, cfg=algo_cfg, problem=problem)
    solver.run(init_guess)

    # Close the run explicitly so repeated invocations in one process (e.g.
    # Hydra multirun with the basic launcher) each get their own run.
    # NOTE(review): SwanLab's handling of a second init() on a still-open run
    # is version-dependent; confirm against the installed version.
    swanlab.finish()


if __name__ == "__main__":
    main()
