"""
Dual FEG (dual_feg) setup module for pep_runner.py.

Algorithm: Dual Fast Extragradient (Dual FEG) with fixed step size alpha = 1 / L.
Problem class: A is a monotone, L-Lipschitz operator.
Performance metric: ||A(x_N)||^2.
Initial condition: ||x_0 - x_star||^2 <= R^2, where A(x_star) = 0.
Conjectured rate: unknown.
"""

import itertools

import attrs

import pepflow as pf

# Lipschitz constant of the operator A, declared as a pepflow parameter named
# "L" (its numeric value is presumably supplied by pep_runner.py — confirm).
L = pf.Parameter("L")
# Fixed Dual FEG step size alpha = 1 / L, used by every step in
# make_ctx_dual_feg.
alpha = 1 / L


@attrs.frozen(kw_only=True, repr=False)
class LipschitzMonotoneOperator(pf.MonotoneOperator):
    """Monotone operator that additionally satisfies an L-Lipschitz condition.

    On top of the constraint groups produced by ``pf.MonotoneOperator``, every
    unordered pair of evaluation points tracked for this operator in a context
    receives the interpolation constraint

        ||A(x_i) - A(x_j)||^2 - L^2 * ||x_i - x_j||^2 <= 0.
    """

    # Lipschitz constant. Typed as ``object`` so a symbolic pf.Parameter can be
    # passed (this module passes one).
    L: object

    def __hash__(self):
        # Delegate to the parent's hash rather than relying on an attrs-generated
        # hash over this class's fields (presumably to keep pepflow's notion of
        # operator identity — confirm).
        return super().__hash__()

    def lipschitz_inequality_constraints(self, duplet_i, duplet_j):
        """Return the Lipschitz constraint for one pair of evaluated duplets.

        The constraint reads
        ``(A_i - A_j)^2 - L^2 * (x_i - x_j)^2 <= 0``. It is named
        ``"<operator tag>_Lipschitz:<tag of x_i>,<tag of x_j>"``.
        """
        output_gap_sq = (duplet_i.output - duplet_j.output) ** 2
        scaled_input_gap_sq = (self.L**2) * (duplet_i.point - duplet_j.point) ** 2
        violation = output_gap_sq - scaled_input_gap_sq
        return violation.le(
            0,
            name=f"{self.tag}_Lipschitz:{duplet_i.point.tag},{duplet_j.point.tag}",
        )

    def get_interpolation_constraints_by_group(self, pep_context=None):
        """Return the parent's constraint groups plus a pairwise Lipschitz group.

        Args:
            pep_context: Context whose tracked points are constrained. When
                None, the current pepflow context is used.

        Returns:
            The constraint-group object from the parent class, extended with a
            "Lipschitz Operator Inequality" group.

        Raises:
            RuntimeError: If no context was passed and none is current.
        """
        groups = super().get_interpolation_constraints_by_group(pep_context)
        context = (
            pep_context if pep_context is not None else pf.get_current_context()
        )
        if context is None:
            raise RuntimeError("Did you forget to create a context?")

        duplets = [
            context.get_duplet_by_point_tag(tracked.tag, self)
            for tracked in context.tracked_point(self)
        ]
        # The constraint is symmetric in (i, j), so unordered pairs suffice.
        pairwise = [
            self.lipschitz_inequality_constraints(first, second)
            for first, second in itertools.combinations(duplets, 2)
        ]
        groups.add_sc_constraint("Lipschitz Operator Inequality", pairwise)
        return groups


# Module-level operator shared by make_ctx_dual_feg and get_pep_setup, and
# returned to the runner by get_pep_setup. It is monotone and L-Lipschitz with
# the module parameter L. ``is_basis=True`` presumably marks it as a primitive
# (non-composite) operator — confirm against the pepflow docs.
A = LipschitzMonotoneOperator(is_basis=True, tags=["A"], L=L)


def make_ctx_dual_feg(ctx_name: str, N, params) -> pf.PEPContext:
    """Build the PEPContext encoding N steps of Dual FEG.

    The run starts from x_0 (a basis vector) and z_0 = 0. Step k = 0, ..., N-1
    uses rho_k = (N - k - 1) / (N - k) and eta_k = 1 / (N - k):

        x_{k+1/2} = x_k - alpha * z_k - alpha * A(x_k)
        x_{k+1}   = x_{k+1/2} - rho_k * alpha * (A(x_{k+1/2}) - A(x_k))
        z_{k+1}   = rho_k * z_k - eta_k * A(x_{k+1/2})

    Tags created: ``x_0``, ``z_0``, and ``x_star`` (the zero point of A). For
    each step k, x_{k+1/2} is tagged ``x_{k+1}_half``, x_{k+1} is tagged
    ``x_{k+1}`` and z_{k+1} is tagged ``z_{k+1}``. At the last step rho = 0,
    so x_N is the same point as the final half step.

    Args:
        ctx_name: Name of the new context; it is set as the current context.
        N: Number of iterations. Must have a non-negative integer value
            (e.g. ``3`` or ``3.0``).
        params: Unused by this algorithm.

    Returns:
        The populated context.

    Raises:
        ValueError: If ``N`` is not a non-negative integer.
    """
    del params  # Not used by this algorithm.

    # Validate before creating any context. The loop needs an integral,
    # non-negative count. A fractional N would silently truncate the loop while
    # the step coefficients still used the fractional value.
    num_steps = int(N)
    if num_steps != N or num_steps < 0:
        raise ValueError(f"N must be a non-negative integer, got {N!r}")

    ctx = pf.PEPContext(ctx_name).set_as_current()
    x = pf.Vector(is_basis=True, tags=["x_0"])
    z = pf.Vector.zero()
    z.add_tag("z_0")
    A.set_zero_point("x_star")

    for i in range(num_steps):
        # Coefficients are computed from N itself (not num_steps) so the
        # arithmetic type of the caller's N is preserved.
        rho = (N - i - 1) / (N - i)
        eta = 1 / (N - i)

        Ax = A(x)
        x_half = x - alpha * z - alpha * Ax
        x_half.add_tag(f"x_{i + 1}_half")

        A_x_half = A(x_half)
        x = x_half - rho * alpha * (A_x_half - Ax)
        x.add_tag(f"x_{i + 1}")

        z = rho * z - eta * A_x_half
        z.add_tag(f"z_{i + 1}")

    return ctx


def get_pep_setup(N, params):
    """Standard interface for pep_runner.py: build the Dual FEG PEP for N steps.

    Args:
        N: Number of Dual FEG iterations (an integer value, e.g. ``3`` or ``3.0``).
        params: Forwarded unchanged to ``make_ctx_dual_feg``.

    Returns:
        Tuple ``(ctx, pb, A)``: the PEP context, a PEPBuilder carrying the
        initial condition ``||x_0 - x_star||^2 <= R^2`` and the performance
        metric ``||A(x_N)||^2``, and the module-level operator ``A``.
    """
    R = pf.Parameter("R")
    ctx = make_ctx_dual_feg(f"ctx_{N}", N, params)
    pb = pf.PEPBuilder(ctx)
    pb.add_initial_constraint(
        ((ctx["x_0"] - ctx["x_star"]) ** 2).le(R**2, name="initial_condition")
    )
    # Iterate tags are built from an integer loop index ("x_{i + 1}"), so the
    # last one is f"x_{int(N)}". Formatting N directly would look up "x_3.0"
    # for N=3.0 and raise a KeyError.
    x_last = ctx[f"x_{int(N)}"]
    pb.set_performance_metric(A(x_last) ** 2)
    return ctx, pb, A
