"""
OGM setup module for pep_runner.py.

Algorithm: Optimized gradient method (OGM) on an L-smooth convex function f,
    with gradient steps scaled by 1 / L.
Performance metric: f(x_N) - f(x_star).
Initial condition: ||x_0 - x_star|| <= R, where grad f(x_star) = 0.
Conjectured rate: L * R^2 / (2 * theta_N^2), where theta_N = theta_ogm(N, N).
"""

from __future__ import annotations

import functools

import numpy as np

import pepflow as pf

# Objects created once at import time and reused by every get_pep_setup call:
# the smoothness parameter L and the L-smooth convex basis function f.
L = pf.Parameter("L")
f = pf.SmoothConvexFunction(L=L, tags=["f"], is_basis=True)


@functools.cache
def _theta_ogm_sequence(N: int) -> tuple[float, ...]:
    """Return (theta_{-1}, theta_0, ..., theta_N) for horizon N as plain floats."""
    thetas = [0.0]  # theta_{-1}
    for j in range(N + 1):
        # The final coefficient theta_N uses 8 instead of 4 under the root.
        coeff = 8 if j == N else 4
        thetas.append(float(0.5 * (1 + np.sqrt(coeff * thetas[-1] ** 2 + 1))))
    return tuple(thetas)


@functools.cache
def theta_ogm(i: int, N: int) -> float:
    """Return the OGM coefficient theta_i for a fixed horizon N.

    The coefficients follow the recursion
        theta_{-1} = 0,
        theta_i    = (1 + sqrt(4 * theta_{i-1}^2 + 1)) / 2   for 0 <= i < N,
        theta_N    = (1 + sqrt(8 * theta_{N-1}^2 + 1)) / 2,
    so in particular theta_0 = 1.

    Args:
        i: Coefficient index, with -1 <= i <= N.
        N: Iteration horizon of the method.

    Returns:
        theta_i as a plain Python float (not a NumPy scalar), so it can be
        combined with pepflow expressions without NumPy operator dispatch.

    Raises:
        ValueError: If i lies outside [-1, N].
    """
    if not -1 <= i <= N:
        raise ValueError(f"theta_ogm index i={i} is outside [-1, N={N}]")
    # The sequence is built iteratively (and cached per N) rather than by
    # recursing on theta_{i-1}, so direct calls such as theta_ogm(N, N) for a
    # large N cannot exceed Python's recursion limit.
    return _theta_ogm_sequence(N)[i + 1]


def make_ctx_ogm(ctx_name: str, N, params) -> pf.PEPContext:
    """Create the PEPContext ``ctx_name`` describing int(N) OGM iterations.

    The new context is made current. Starting from a basis vector tagged
    ``x_0`` (with z_0 = x_0), iteration k adds points tagged ``y_k``,
    ``z_{k+1}`` and ``x_{k+1}``:

        y_k     = x_k - (1 / L) * grad f(x_k)
        z_{k+1} = z_k - (2 / L) * theta_k * grad f(x_k)
        x_{k+1} = (1 - 1 / theta_{k+1}) * y_k + (1 / theta_{k+1}) * z_{k+1}

    where L and f are the module-level objects. The stationary point of f is
    registered under the name ``x_star``. ``params`` is not used here.
    """
    del params  # not needed by OGM; kept in the signature

    horizon = int(N)
    context = pf.PEPContext(ctx_name).set_as_current()
    x_k = pf.Vector(is_basis=True, tags=["x_0"])
    z_k = x_k
    f.set_stationary_point("x_star")

    for k in range(horizon):
        theta_cur = theta_ogm(k, horizon)
        theta_nxt = theta_ogm(k + 1, horizon)
        g_k = f.grad(x_k)

        # Plain gradient step.
        y_k = x_k - (1 / L) * g_k
        y_k.add_tag(f"y_{k}")

        # Aggregated, theta-weighted gradient step.
        z_k = z_k - (2 / L) * theta_cur * g_k
        z_k.add_tag(f"z_{k + 1}")

        # Convex combination of the two sequences.
        inv_theta = 1 / theta_nxt
        x_k = (1 - inv_theta) * y_k + inv_theta * z_k
        x_k.add_tag(f"x_{k + 1}")

    return context


def get_pep_setup(N, params):
    """Standard interface for pep_runner.py.

    Build the N-step OGM context and a PEPBuilder with the initial constraint
    ||x_0 - x_star||^2 <= R^2 and the performance metric f(x_N) - f(x_star).

    Args:
        N: Number of OGM iterations; any value accepted by int(), must be >= 0.
        params: Forwarded to make_ctx_ogm.

    Returns:
        Tuple (ctx, pb, f): the PEPContext, the configured PEPBuilder and the
        module-level function object.

    Raises:
        ValueError: If int(N) is negative.
    """
    n_steps = int(N)
    if n_steps < 0:
        # Otherwise a context would be built and the failure would only show
        # up later, when looking up the nonexistent "x_{N}" point.
        raise ValueError(f"OGM horizon N must be non-negative, got {N!r}")
    R = pf.Parameter("R")
    ctx = make_ctx_ogm(f"ctx_{N}", N, params)
    pb = pf.PEPBuilder(ctx)
    pb.add_initial_constraint(
        ((ctx["x_0"] - ctx["x_star"]) ** 2).le(R**2, name="initial_condition")
    )
    pb.set_performance_metric(f(ctx[f"x_{n_steps}"]) - f(ctx["x_star"]))
    return ctx, pb, f
