"""
fast_extragradient setup module for pep_runner.py.

Algorithm: Fast Extragradient (FEG) with fixed step size alpha = 1 / L and anchoring weight 1 / (k + 1) toward x_0.
Performance metric: ||A(x_N)||^2.
Initial condition: ||x_0 - x_star||^2 <= R^2, where A(x_star) = 0.
Conjectured rate: unknown.
"""

import itertools

import attrs

import pepflow as pf

# Named pepflow Parameter for the Lipschitz constant of the operator A
# (it is passed to LipschitzMonotoneOperator as L=L).
L = pf.Parameter("L")
# Fixed step size alpha = 1 / L, used by both the half-step and the full step
# of every Fast Extragradient iteration.
alpha = 1 / L


@attrs.frozen(kw_only=True, repr=False)
class LipschitzMonotoneOperator(pf.MonotoneOperator):
    """Monotone operator that additionally obeys an L-Lipschitz bound.

    On top of whatever interpolation constraints ``pf.MonotoneOperator``
    contributes, every pair of tracked points (x_i, x_j) receives the scalar
    constraint ``||A(x_i) - A(x_j)||^2 - L^2 * ||x_i - x_j||^2 <= 0``.
    """

    # Lipschitz constant, e.g. the module-level pf.Parameter "L".
    L: object

    def __hash__(self):
        # Delegate hashing to the parent class.
        # NOTE(review): presumably here so attrs does not synthesize a
        # field-based hash that would include L; confirm against attrs docs.
        return super().__hash__()

    def lipschitz_inequality_constraints(self, duplet_i, duplet_j):
        """Build the Lipschitz inequality for one pair of duplets.

        Args:
            duplet_i: First (point, output) duplet of this operator.
            duplet_j: Second (point, output) duplet of this operator.

        Returns:
            The constraint ``||out_i - out_j||^2 - L^2 ||pt_i - pt_j||^2 <= 0``,
            named ``<op tag>_Lipschitz:<point_i tag>,<point_j tag>``.
        """
        output_gap_sq = (duplet_i.output - duplet_j.output) ** 2
        scaled_point_gap_sq = (self.L**2) * (duplet_i.point - duplet_j.point) ** 2
        label = f"{self.tag}_Lipschitz:{duplet_i.point.tag},{duplet_j.point.tag}"
        return (output_gap_sq - scaled_point_gap_sq).le(0, name=label)

    def get_interpolation_constraints_by_group(self, pep_context=None):
        """Return the parent's constraint groups plus pairwise Lipschitz bounds.

        Args:
            pep_context: Context whose tracked points are constrained. When
                omitted, ``pf.get_current_context()`` is used.

        Returns:
            The collection produced by the parent class, extended with a
            "Lipschitz Operator Inequality" group holding one constraint per
            unordered pair of tracked points.

        Raises:
            RuntimeError: If no context is passed and none is current.
        """
        groups = super().get_interpolation_constraints_by_group(pep_context)

        context = pep_context
        if context is None:
            context = pf.get_current_context()
            if context is None:
                raise RuntimeError("Did you forget to create a context?")

        duplets = [
            context.get_duplet_by_point_tag(tracked.tag, self)
            for tracked in context.tracked_point(self)
        ]
        # One constraint per unordered pair (i < j) of tracked duplets.
        pairwise = list(
            itertools.starmap(
                self.lipschitz_inequality_constraints,
                itertools.combinations(duplets, 2),
            )
        )
        groups.add_sc_constraint("Lipschitz Operator Inequality", pairwise)
        return groups


# Single operator instance shared by every context this module builds; its
# Lipschitz constant is the Parameter L.
A = LipschitzMonotoneOperator(L=L, tags=["A"], is_basis=True)


def make_ctx_fast_extragradient(ctx_name: str, N, params) -> pf.PEPContext:
    """Create, activate and populate a PEPContext with N Fast Extragradient steps.

    Starting from the basis vector ``x_0`` (also tagged ``x_0_half``),
    iteration ``k`` forms the anchored half-step

        x_{k}_half = x_k + (x_0 - x_k) / (k + 1) - k / (k + 1) * alpha * A(x_k)

    which is just ``x_0`` when ``k == 0``. It then takes the full step

        x_{k+1} = x_k + (x_0 - x_k) / (k + 1) - alpha * A(x_{k}_half)

    Args:
        ctx_name: Name for the new context.
        N: Number of iterations; converted with ``int``.
        params: Ignored.

    Returns:
        The new context, which has been set as the current context.
    """
    del params  # not needed by this algorithm

    ctx = pf.PEPContext(ctx_name).set_as_current()
    anchor = pf.Vector(is_basis=True, tags=["x_0"])
    # At k = 0 the half-step coincides with x_0, so x_0 carries both tags.
    anchor.add_tag("x_0_half")
    iterate = anchor
    A.set_zero_point("x_star")

    for k in range(int(N)):
        beta = 1 / (k + 1)  # anchoring weight toward x_0
        op_at_iterate = A(iterate)
        if k:
            half_weight = k / (k + 1)
            half = (
                iterate
                + beta * (anchor - iterate)
                - half_weight * alpha * op_at_iterate
            )
            half.add_tag(f"x_{k}_half")
        else:
            half = anchor
        iterate = iterate + beta * (anchor - iterate) - alpha * A(half)
        iterate.add_tag(f"x_{k + 1}")

    return ctx


def get_pep_setup(N, params):
    """Standard interface for pep_runner.py.

    Builds the Fast Extragradient context for N steps and a PEPBuilder with
    the initial condition ``||x_0 - x_star||^2 <= R^2`` and the performance
    metric ``||A(x_N)||^2``.

    Args:
        N: Number of iterations; any value accepted by ``int`` (e.g. 5, "5").
        params: Forwarded to ``make_ctx_fast_extragradient``, which ignores it.

    Returns:
        Tuple ``(ctx, pb, A)``: the context, the configured builder, and the
        module-level operator.

    Raises:
        ValueError: If ``N`` is negative.
    """
    # make_ctx_fast_extragradient creates tags from int(N). Normalize here too
    # so the final-iterate lookup matches: with the raw value, N=3.0 would look
    # up "x_3.0" although only "x_3" exists.
    num_steps = int(N)
    if num_steps < 0:
        raise ValueError(f"N must be non-negative, got {N!r}")

    R = pf.Parameter("R")
    ctx = make_ctx_fast_extragradient(f"ctx_{num_steps}", num_steps, params)
    pb = pf.PEPBuilder(ctx)
    pb.add_initial_constraint(
        ((ctx["x_0"] - ctx["x_star"]) ** 2).le(R**2, name="initial_condition")
    )
    # make_ctx_fast_extragradient left ctx as the current context; presumably
    # that is where A(x_N) gets recorded.
    pb.set_performance_metric(A(ctx[f"x_{num_steps}"]) ** 2)
    return ctx, pb, A
