"""
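BPPM (Bregman proximal point method) setup module for pep_runner.py.

Algorithm: Bregman proximal point method with fixed step size `stepsize`
and Bregman kernel h.
Performance metric: f(x_N) - f(x_star), where x_star is a minimizer of f.
Initial condition: D_h(x_star, x_0) <= R, where D_h is the Bregman distance of h.
Conjectured rate: R / (stepsize * N).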
"""

import pepflow as pf

# Module-level objects shared across get_pep_setup calls: every context built
# by make_ctx_bppm reuses this same stepsize parameter and these same f and h.
# Symbolic fixed step size passed to every proximal step; presumably bound to
# a numeric value by pep_runner.py — confirm.
stepsize = pf.Parameter("stepsize")
# Objective function: the performance metric f(x_N) - f(x_star) is measured on f,
# and its stationary point is tagged "x_star".
f = pf.ConvexFunction(is_basis=True, tags=["f"])
# Bregman kernel: used for D_h in the initial condition and passed as the
# kernel argument of f.bregman_prox in each step.
h = pf.ConvexFunction(is_basis=True, tags=["h"])


def bregman_distance(kernel, x, y):
    """Return the Bregman distance D_kernel(x, y).

    Computes kernel(x) - kernel(y) - grad kernel(y) * (x - y), i.e. the gap
    between kernel at x and its first-order expansion around y. For pepflow
    vectors the ``*`` is presumably pepflow's inner product — confirm.
    """
    value_at_x = kernel(x)
    value_at_y = kernel(y)
    slope_at_y = kernel.grad(y)
    return value_at_x - value_at_y - slope_at_y * (x - y)


def make_ctx_bppm(ctx_name: str, N, params) -> pf.PEPContext:
    """Create, activate and return a PEPContext holding the BPPM iterates.

    The context contains a basis starting point tagged ``"x_0"``, the
    stationary point of ``f`` tagged ``"x_star"``, and ``int(N)`` iterates
    tagged ``"x_1"`` ... ``"x_{int(N)}"``. Each iterate is produced by
    ``f.bregman_prox`` from the previous one, using the module-level
    ``stepsize`` and kernel ``h``. This is presumably
    x_{k+1} = argmin_u { f(u) + D_h(u, x_k) / stepsize }; confirm against pepflow.

    ``params`` is accepted only so all setup modules share one signature.
    """
    _ = params  # not needed by this algorithm

    context = pf.PEPContext(ctx_name).set_as_current()
    iterate = pf.Vector(is_basis=True, tags=["x_0"])
    f.set_stationary_point("x_star")

    num_steps = int(N)
    for step in range(1, num_steps + 1):
        iterate = f.bregman_prox(iterate, stepsize, h, tag=f"x_{step}")

    return context


def get_pep_setup(N, params):
    """Build the BPPM performance-estimation problem (standard pep_runner.py interface).

    Args:
        N: Number of Bregman proximal point steps. The last-iterate tag is
            built from ``int(N)``, the same normalization make_ctx_bppm uses
            for its loop. An integral value such as ``3.0`` therefore selects
            ``"x_3"`` rather than the non-existent ``"x_3.0"``.
        params: Forwarded unchanged to make_ctx_bppm.

    Returns:
        Tuple ``(ctx, pb, f)``:
        - ``ctx``: the PEPContext holding the iterates.
        - ``pb``: a PEPBuilder with the initial constraint
          ``D_h(x_star, x_0) <= R`` (named ``"initial_condition"``) and the
          performance metric ``f(x_N) - f(x_star)``.
        - ``f``: the module-level objective.
    """
    # Symbolic bound on the initial Bregman distance; presumably bound to a
    # numeric value by pep_runner.py — confirm.
    R = pf.Parameter("R")
    ctx = make_ctx_bppm(f"ctx_{N}", N, params)
    # make_ctx_bppm tags the last iterate with int(N); match it here so the
    # lookup below does not fail for non-int but integral N.
    num_steps = int(N)
    pb = pf.PEPBuilder(ctx)
    pb.add_initial_constraint(
        bregman_distance(h, ctx["x_star"], ctx["x_0"]).le(R, name="initial_condition")
    )
    pb.set_performance_metric(f(ctx[f"x_{num_steps}"]) - f(ctx["x_star"]))
    return ctx, pb, f
