"""Regularized Momentum primal-dual Projected Gradient algorithm (RMPDPG)

Zhang, Hui-Li et al.
“Zeroth-Order primal-dual Alternating Projection Gradient Algorithms for Nonconvex Minimax Problems with Coupled linear Constraints.”
arXiv preprint arXiv:2402.03352.
"""

from __future__ import annotations

from typing import Optional, Callable
from dataclasses import dataclass
import numpy as np

from toy_example.problems import ToyExampleSol
from toy_example.common import GradEstimatorX, GradEstimatorY, ConstraintFn, Projection, Array, BasicAlgoConfig
from toy_example.algorithms.log_utils import BaseLogger


# ---- Config and state -----------------
@dataclass(slots=True)
class AlgoConfig(BasicAlgoConfig):
    """Hyperparameters for RMPDPG.

    Each schedule field (``alpha`` ... ``dual_stepsize``) holds either a
    constant or a callable that maps the iteration index ``k`` to a value.
    The matching ``*_at(k)`` accessor resolves whichever form is stored.
    """

    max_iters: int = 1000
    log_every: int = 100
    seed: int | None = 42

    # Iteration-dependent schedules, kept for reference (currently disabled):
    # alpha: float | Callable[[int], float] = lambda k: 0.001 / (k+2) ** (4/13)                   # x stepsize
    # beta: float | Callable[[int], float] = 0.015                                                # y stepsize
    # eta_x: float | Callable[[int], float] = lambda k: 1.0 / (k+2) ** (12/13)                    # momentum blend for x (STORM)
    # eta_y: float | Callable[[int], float] = lambda k: 2.0 / (k+2) ** (8/13)                     # momentum blend for y (STORM)
    # prox: float | Callable[[int], float] = 1e-4                                                 # prox parameter for y
    # intpl: float | Callable[[int], float] = lambda k: 10 / (10 + (k+2) ** (5/13))               # interpolation parameter
    # dual_stepsize: float | Callable[[int], float] = lambda k: 0.01 / (10 + (k+2) ** (4/13))     # dual stepsize

    # Active (constant) schedules.
    alpha: float | Callable[[int], float] = 1e-3          # x stepsize
    beta: float | Callable[[int], float] = 1e-2           # y stepsize
    eta_x: float | Callable[[int], float] = 0.8           # momentum blend for x (STORM)
    eta_y: float | Callable[[int], float] = 0.8           # momentum blend for y (STORM)
    prox: float | Callable[[int], float] = 1e-4           # prox parameter for y
    intpl: float | Callable[[int], float] = 0.8           # interpolation parameter
    dual_stepsize: float | Callable[[int], float] = 1.0   # dual stepsize

    # Whether to apply the STORM-style momentum correction.
    is_storm: bool = True

    @staticmethod
    def _value_at(schedule, k: int) -> float:
        """Resolve ``schedule`` at iteration ``k``.

        A callable is invoked with ``k`` and its result returned as-is;
        anything else is treated as a constant and coerced with ``float``.
        """
        if callable(schedule):
            return schedule(k)
        return float(schedule)

    def alpha_at(self, k: int) -> float:
        """Return the x stepsize for iteration ``k``."""
        return self._value_at(self.alpha, k)

    def beta_at(self, k: int) -> float:
        """Return the y stepsize for iteration ``k``."""
        return self._value_at(self.beta, k)

    def eta_x_at(self, k: int) -> float:
        """Return the x momentum blend for iteration ``k``."""
        return self._value_at(self.eta_x, k)

    def eta_y_at(self, k: int) -> float:
        """Return the y momentum blend for iteration ``k``."""
        return self._value_at(self.eta_y, k)

    def prox_at(self, k: int) -> float:
        """Return the prox parameter for iteration ``k``."""
        return self._value_at(self.prox, k)

    def intpl_at(self, k: int) -> float:
        """Return the interpolation parameter for iteration ``k``."""
        return self._value_at(self.intpl, k)

    def dual_stepsize_at(self, k: int) -> float:
        """Return the dual stepsize for iteration ``k``."""
        return self._value_at(self.dual_stepsize, k)


@dataclass(slots=True)
class AlgoState:
    x: Array
    y: Array
    lam: Array
    iter: int = 1
    x_prev: Array | None = None # previous x, for STORM momentum
    y_prev: Array | None = None # previous y, for STORM momentum
    gx_prev: Array | None = None # previous gradient of x, for STORM momentum
    gy_prev: Array | None = None # previous gradient of y, for STORM momentum


# ---- Algorithm implementation ---------
class RMPDPG:
    """Regularized Momentum primal-dual Projected Gradient (RMPDPG).

    Each iteration:
      1. forms the gradient estimates ``grad_x - lam @ grad_cx`` and
         ``grad_y - lam @ grad_cy - prox * y``, optionally with a STORM-style
         recursive momentum correction;
      2. takes a projected descent step in ``x`` and a projected ascent step
         in ``y``;
      3. moves the multiplier ``lam`` along the constraint value and clips it
         at zero from below;
      4. relaxes every new iterate towards the old one with weight ``intpl``.
    """

    def __init__(
        self,
        *,
        cfg: AlgoConfig,
        proj_x: Projection,
        proj_y: Projection,
        constraint: ConstraintFn,
        grad_y: GradEstimatorY,
        grad_x: GradEstimatorX,
        grad_cx: Callable[[Array, Array], Array],
        grad_cy: Callable[[Array, Array], Array],
        noise_sampler: Callable[[Optional[int]], Array],
        objective_fn: Callable[[Array, Array], float],
        exact_sol: Optional[ToyExampleSol] = None,
    ) -> None:
        """Store the problem oracles and build the logger.

        Args:
            cfg: Hyperparameters and schedules.
            proj_x: Applied to the tentative ``x`` iterate.
            proj_y: Applied to the tentative ``y`` iterate.
            constraint: Called as ``constraint(x, y)``; its value drives the
                dual update and fixes the shape of ``lam``.
            grad_y: Objective gradient estimator in ``y``, called as
                ``grad_y(x, y, samples=...)``.
            grad_x: Objective gradient estimator in ``x``, called as
                ``grad_x(x, y, samples=...)``.
            grad_cx: Called as ``grad_cx(x, y)`` and left-multiplied by
                ``lam`` (presumably the constraint Jacobian in ``x`` --
                confirm against the problem definition).
            grad_cy: Same as ``grad_cx`` for ``y``.
            noise_sampler: Called with ``None`` once per step; the result is
                passed to ``grad_x`` as ``samples``.
            objective_fn: Stored and forwarded to the logger.
            exact_sol: Optional reference solution, forwarded to the logger.
        """
        self.cfg = cfg
        self.proj_x = proj_x
        self.proj_y = proj_y
        self.constraint = constraint
        self.grad_y = grad_y
        self.grad_x = grad_x
        self.grad_cx = grad_cx
        self.grad_cy = grad_cy
        self.noise_sampler = noise_sampler
        self.objective_fn = objective_fn
        self.exact_sol = exact_sol
        self.logger = BaseLogger(objective_fn, constraint, exact_sol)

    def run(self, init: ToyExampleSol) -> AlgoState:
        """Run ``cfg.max_iters`` steps starting from ``init``.

        Args:
            init: Starting point; only ``init.x`` and ``init.y`` are read
                (both are copied).

        Returns:
            The state after the last step.
        """
        # The multiplier starts at zero with the shape of the constraint value.
        lam0 = np.zeros_like(self.constraint(init.x, init.y))
        state = AlgoState(init.x.copy(), init.y.copy(), lam0)
        self._log(0, state)  # log initial state

        for k in range(1, self.cfg.max_iters + 1):
            state.iter = k
            state = self._step(state)

            # A falsy log_every disables periodic logging.
            if self.cfg.log_every and k % self.cfg.log_every == 0:
                self._log(k, state)
        return state

    def _step(self, state: AlgoState) -> AlgoState:
        """Perform one primal-dual iteration and return the next state."""
        # Resolve the schedules at the current iteration.
        k = state.iter
        cfg = self.cfg
        alpha = cfg.alpha_at(k)
        beta = cfg.beta_at(k)
        eta_x = cfg.eta_x_at(k)
        eta_y = cfg.eta_y_at(k)
        prox = cfg.prox_at(k)
        intpl = cfg.intpl_at(k)
        dual_stepsize = cfg.dual_stepsize_at(k)
        x_prev, y_prev = state.x_prev, state.y_prev

        # NOTE(review): both the current-point and previous-point gradients
        # below use the current multiplier ``state.lam``; the previous
        # multiplier is not stored. Confirm against the reference algorithm.

        # --- y gradient ---------------------------------------------------
        # No samples are drawn for the y estimator (per the original author's
        # note, the current grad_y does not need any).
        y_samples = None
        org_gy = self._compute_grad_y(state.x, state.y, state.lam, prox=prox, samples=y_samples)
        if state.gy_prev is None or not cfg.is_storm:
            gy = org_gy
        else:
            assert x_prev is not None and y_prev is not None, "x_prev, y_prev must be set if gy_prev is set"
            # STORM: d_k = g(z_k) + (1 - eta) * (d_{k-1} - g(z_{k-1})).
            gy_prev_corr = self._compute_grad_y(x_prev, y_prev, state.lam, prox=prox, samples=None)
            gy = org_gy + (1.0 - eta_y) * (state.gy_prev - gy_prev_corr)

        # --- x gradient ---------------------------------------------------
        # The same samples are used at the current and the previous point.
        x_samples = self.noise_sampler(None)
        org_gx = self._compute_grad_x(state.x, state.y, state.lam, samples=x_samples)
        if state.gx_prev is None or not cfg.is_storm:
            gx = org_gx
        else:
            assert x_prev is not None and y_prev is not None, "x_prev, y_prev must be set if gx_prev is set"
            gx_prev_corr = self._compute_grad_x(x_prev, y_prev, state.lam, samples=x_samples)
            gx = org_gx + (1.0 - eta_x) * (state.gx_prev - gx_prev_corr)

        # --- primal update: projected step, then interpolation -------------
        x_new = self.proj_x(state.x - alpha * gx)
        x_new = state.x + intpl * (x_new - state.x)
        y_new = self.proj_y(state.y + beta * gy)
        y_new = state.y + intpl * (y_new - state.y)

        # --- dual update: step along the constraint at the new primal point,
        # clip at zero from below, then interpolate --------------------------
        con = self.constraint(x_new, y_new)
        lam_new = np.clip(state.lam + dual_stepsize * con, 0, None)
        lam_new = state.lam + intpl * (lam_new - state.lam)

        return AlgoState(
            x=x_new, y=y_new, lam=lam_new, iter=state.iter + 1,
            x_prev=state.x, y_prev=state.y,
            gx_prev=gx, gy_prev=gy,
        )

    def _compute_grad_x(self, x: Array, y: Array, lam: Array, *, samples: Optional[Array] = None) -> Array:
        """Return ``grad_x(x, y) - lam @ grad_cx(x, y)`` as a new array."""
        # Out-of-place on purpose: the estimator may hand back a cached or
        # shared buffer (or an integer array), which must not be overwritten.
        return self.grad_x(x, y, samples=samples) - lam @ self.grad_cx(x, y)

    def _compute_grad_y(self, x: Array, y: Array, lam: Array, *, prox: float, samples: Optional[Array] = None) -> Array:
        """Return ``grad_y(x, y) - lam @ grad_cy(x, y) - prox * y`` as a new array."""
        # objective gradient
        obj_grad = self.grad_y(x, y, samples=samples)

        # constraint gradient (out-of-place; never write into the estimator's array)
        grad = obj_grad - lam @ self.grad_cy(x, y)

        # proximal term
        return grad - prox * y

    def _log(self, step: int, state: AlgoState) -> None:
        """Log ``state`` at ``step`` and print the iterates near the end of the run."""
        lam_norm = np.linalg.norm(state.lam)

        extra_metrics = {
            "metrics/lam_norm": lam_norm,
        }

        extra_log_items = [
            f"||lam||={lam_norm:.2e}",
        ]

        self.logger.log(step, state, extra_metrics, extra_log_items)

        # Dump the raw iterates once within roughly the last five logging
        # periods. Note that ``state.iter`` is ``step + 1`` after a step
        # (and 1 for the initial state).
        if state.iter >= self.cfg.max_iters - 5 * self.cfg.log_every:
            with np.printoptions(precision=2, suppress=False):
                print("x:", state.x)
                print("y:", state.y)
                print("lam:", state.lam)
